Commit 68ab1ef4 · Author: liaogang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into cpu_mem

group: deprecated-2017Q2
language: cpp
cache:
  directories:
......
@@ -93,6 +93,7 @@ include(external/openblas) # download, build, install openblas
include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(generic) # simplify cmake module
include(package) # set paddle packages
......
INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3)
ExternalProject_Add(
eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
URL_MD5 "1a47e78efe365a97de0c022d127607c3"
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
LIST(APPEND external_project_dependencies eigen3)
@@ -77,6 +77,15 @@
# /cmake/external/*.cmake:
#
#   cc_test(example_test SRCS example_test.cc DEPS example glog gflags)
#
# To build a go static library using Golang, use the go_ prefixed version:
#
# go_library(example STATIC)
#
# To build a go shared library using Golang, use the go_ prefixed version:
#
# go_library(example SHARED)
#
if(NOT APPLE)
  find_package(Threads REQUIRED)
@@ -246,42 +255,53 @@ endfunction(nv_test)
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle")
# Because api.go defines a GO wrapper to ops and tensor, it depends on
# both. This implies that if any of tensor.{h,cc}, ops.{h,cu}, or
# api.go is changed, api needs to be re-built.
# go_library(api
# SRCS
# api.go
# DEPS
# tensor # Because ops depend on tensor, this line is optional.
# ops)
function(go_library TARGET_NAME)
-  set(options OPTIONAL)
  set(options STATIC static SHARED shared)
  set(oneValueArgs "")
-  set(multiValueArgs SRCS DEPS)
  set(multiValueArgs DEPS)
  cmake_parse_arguments(go_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-  if (${go_library_OPTIONAL} STREQUAL "SHARED")
  if (go_library_SHARED OR go_library_shared)
    set(BUILD_MODE "-buildmode=c-shared")
-    if(APPLE)
-      set(LIB_NAME "lib${TARGET_NAME}.dylib")
-    else()
-      set(LIB_NAME "lib${TARGET_NAME}.so")
-    endif()
    set(LIB_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}")
  else()
    set(BUILD_MODE "-buildmode=c-archive")
-    set(LIB_NAME "lib${TARGET_NAME}.a")
    set(LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}")
  endif()
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
# Add dummy code to support `make target_name` under Terminal Command
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
if (go_library_SHARED OR go_library_shared)
add_library(${TARGET_NAME} SHARED ${dummyfile})
else()
add_library(${TARGET_NAME} STATIC ${dummyfile})
endif()
if(go_library_DEPS)
add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS)
# we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go
# get ./... will download a new Paddle repo from Github,
# without the changes in our current Paddle repo that we
# want to build.
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
# Symlink Paddle directory into GOPATH
COMMAND mkdir -p ${PADDLE_IN_GOPATH}
COMMAND rm -rf ${PADDLE_IN_GOPATH}
COMMAND ln -sf ${CMAKE_SOURCE_DIR} ${PADDLE_IN_GOPATH}
# Automatically get all dependencies specified in the source code
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ./...
# Golang build source code
    COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
    -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
-    ${go_library_SRCS}
    ${GO_SOURCE}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME}_lib ALL DEPENDS ${TARGET_NAME}_timestamp ${go_library_DEPS})
add_library(${TARGET_NAME} STATIC IMPORTED)
set_property(TARGET ${TARGET_NAME} PROPERTY
IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}")
add_dependencies(${TARGET_NAME} ${TARGET_NAME}_lib)
endfunction(go_library)
function(go_binary TARGET_NAME)
@@ -311,10 +331,3 @@ function(go_test TARGET_NAME)
  add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
  add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test)
# go_extern will download extern go project.
# go_extern(target_name extern_source)
# go_extern(go_redis github.com/hoisie/redis)
function(go_extern TARGET_NAME)
add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN})
endfunction(go_extern)
@@ -33,6 +33,7 @@ ELSE(WIN32)
    SET(CMAKE_OSX_DEPLOYMENT_TARGET ${MACOS_VERSION} CACHE STRING
        "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value.")
  ENDIF()
  set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
ELSE(APPLE)
  IF(EXISTS "/etc/issue")
......
# Design of Scope in Paddle
## Overview
Scope is an important concept in programming languages: it defines a region of a program in which a set of bindings between names and entities applies. Within a scope, a valid name is uniquely associated with an entity, such as a variable; in another scope, the same name may refer to a different entity or to nothing at all. Scopes thus restrict the visibility and validity of names in a program. Hence **Scope** is introduced into PaddlePaddle to manage variables in context. Unlike the original abstract concept, Scope here is an object with two important attributes:
- Scope is an association of a name to a variable.
- Variables in a parent scope can be retrieved from its local scopes.
A detailed explanation of these two attributes follows.
## Scope is an association of a name to a variable

Scope is an association of a name to a variable. All variables belong to a `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can run in different scopes and update different variables in those scopes.

1. Scope only contains a map from names to variables.

   All parameters, data, and states in a Net should be variables and stored inside a scope. Each op should get its inputs and outputs from a scope, such as data buffers and states (e.g., momentum).

1. A variable can only be created by a Scope and can only be retrieved from a Scope. Users cannot create or get a variable outside a scope. This is a constraint of our framework that keeps it simple and clear.

1. Scope only contains methods used to Create and Get Variables. Scope does not contain Operators and has no information about how to run them.

   `Net` is designed to drive the computation, while Scope only contains a map of variables. There is no computation logic inside a `Scope`; Scope just handles the lifetime management of variables.

   - `Create` creates a Variable by name and adds the mapping relation.
   - `Get` finds a Variable by name.

1. Every variable belongs to exactly one Scope.

   A variable cannot belong to many scopes. If you want to use variables from a parent scope, you can use the `parent scope`.

1. A Scope should destruct all Variables inside it when it is itself destructed. Users can never store a `Variable` pointer anywhere else.

   Because a Variable can only be obtained from a Scope, when a Scope is destroyed we also need to destroy all the Variables in it. If a user stores a `Variable` pointer in a private data member or in some global variable, that pointer becomes invalid once the associated `Scope` is destroyed.
```cpp
class Scope {
public:
Variable* CreateVariable(const std::string& name);
const Variable* GetVariable(const std::string& name) const;
private:
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};
```
## Parent scope and local scope

Just like [scope](https://en.wikipedia.org/wiki/Scope_(computer_science)) in programming languages, a `Scope` in the neural network can also be a local scope. There are two attributes of a local scope.

1. We can create local variables in a local scope. When that local scope is destroyed, all local variables should be destroyed as well.
2. Variables in a parent scope can be retrieved from the local scopes of that parent scope, i.e., when a user gets a variable from a scope, the scope first searches for it locally; if there is no such variable in the local scope, the scope keeps searching in its parent, until the variable is found or there is no parent.
```cpp
class Scope {
public:
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const {
Variable* var = GetVarLocally(name);
if (var != nullptr) {
return var;
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
private:
std::shared_ptr<Scope> parent_ {nullptr};
};
```
In the `Scope` class there is a private data member called `parent_`, a smart pointer to the parent scope. When a user `Get`s a variable by its `name`, the `name` is first searched in the current scope. If the variable cannot be found locally and the parent scope is not `nullptr`, the search continues in that parent scope. The default value of `parent_` is `nullptr`, which means that a scope with a null `parent_` is a global scope.

A local scope is very useful when implementing a Recurrent Neural Network. Each timestep of an RNN should be a `Net`, and each timestep's `Net` (`StepNet` for short) should use an independent local scope, just like variables in a while loop living in a local scope in programming languages. By using a single `StepNet` and switching the local scope, we can implement an RNN easily, as the sketch below illustrates.
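To make the RNN use case concrete, here is a minimal, self-contained C++ sketch (illustrative only, not PaddlePaddle's actual classes; `Variable` is reduced to an empty struct and the `StepNet` is omitted) showing a per-timestep local scope created over a shared parent scope:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Variable {};  // stand-in for the real Variable

// Minimal scope with the parent-lookup behaviour described above.
class Scope {
 public:
  explicit Scope(std::shared_ptr<Scope> parent = nullptr) : parent_(std::move(parent)) {}

  Variable* CreateVariable(const std::string& name) {
    auto& slot = vars_[name];
    if (slot == nullptr) slot.reset(new Variable());
    return slot.get();
  }

  Variable* GetVariable(const std::string& name) const {
    auto it = vars_.find(name);
    if (it != vars_.end()) return it->second.get();
    return parent_ != nullptr ? parent_->GetVariable(name) : nullptr;
  }

 private:
  std::shared_ptr<Scope> parent_;
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};

int main() {
  auto global = std::make_shared<Scope>();
  global->CreateVariable("W");  // parameter shared by every timestep

  for (int t = 0; t < 3; ++t) {
    Scope step(global);             // local scope for this timestep
    step.CreateVariable("hidden");  // per-step state
    // Both the shared parameter and the local state are visible here.
    std::cout << t << " W:" << (step.GetVariable("W") != nullptr)
              << " hidden:" << (step.GetVariable("hidden") != nullptr) << "\n";
  }  // per-step variables die with the local scope
  return 0;
}
```

Because the per-step scope owns only the step-local variables, destroying it at the end of each iteration releases exactly the state that belongs to that timestep.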
# Interface Design
```cpp
class Variable {
private:
Variable() = default;
friend class Scope;
};
class Scope {
private:
Scope(const std::shared_ptr<Scope>& parent = nullptr);
public:
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
// return nullptr if not found.
Variable* GetVariable(const std::string& name) const;
// return Error if already contains same name variable.
Error CreateVariable(const std::string& name);
private:
std::shared_ptr<Scope> parent_;
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};
```
## Only scope can create a variable
To ensure that only a scope can create a variable, we mark `Variable`'s constructor as private and make `Scope` a friend class of `Variable`. Then only `CreateVariable` can construct a `Variable`.
## When scope destroyed, all variables inside this scope should be destroyed together
The scope holds unique pointers to all of its variables. Users can obtain a variable with `GetVariable`, but they should not keep that pointer as a member variable, because when the scope is destroyed, all variables inside it are destroyed together.
## Sharing a parent scope
A local scope contains a `parent_` pointer, which forms a linked list of scopes. A `shared_ptr` is used because while a local scope is in use, its parents must not be destroyed.

Also, since the parent scope is held as a `shared_ptr`, we can only `Create()` a scope as a shared pointer. We cannot construct a scope as a plain variable, because a plain scope object could not be passed to another scope as its `parent` pointer.
## Orthogonal interface
`GetVariable` returns `nullptr` when `name` is not found, so it can also serve as a `Contains` method. `CreateVariable` returns an `Error` when there is a local name conflict. Combining `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily.
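As an illustration of this orthogonality, the following self-contained sketch (again not the real implementation; it returns a plain `bool` where the design uses `Error`) composes `CreateOrGetVariable` from the two primitives:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Variable {};  // stand-in for the real Variable

// Simplified stand-in for the Scope interface described above.
class Scope {
 public:
  // Returns false on a local name conflict (the design's `Error`).
  bool CreateVariable(const std::string& name) {
    if (vars_.count(name) > 0) return false;
    vars_[name].reset(new Variable());
    return true;
  }

  // Returns nullptr if the name is not found, so it doubles as `Contains`.
  Variable* GetVariable(const std::string& name) const {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : it->second.get();
  }

  // Composed from the two orthogonal primitives.
  Variable* CreateOrGetVariable(const std::string& name) {
    if (Variable* v = GetVariable(name)) return v;
    CreateVariable(name);
    return GetVariable(name);
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};

int main() {
  Scope scope;
  Variable* a = scope.CreateOrGetVariable("w");  // created
  Variable* b = scope.CreateOrGetVariable("w");  // found, not re-created
  assert(a == b);
  return 0;
}
```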
@@ -111,7 +111,7 @@ PaddlePaddle supports different types of input data, mainly including four types, and
# define training dataset reader
def train_reader():
    train_x = np.array([[1, 1], [1, 2], [3, 4], [5, 2]])
-    train_y = np.array([-2, -3, -7, -7])
    train_y = np.array([[-2], [-3], [-7], [-7]])
    def reader():
        for i in xrange(train_y.shape[0]):
            yield train_x[i], train_y[i]
......
if(NOT CMAKE_Go_COMPILER)
if(NOT $ENV{GO_COMPILER} STREQUAL "")
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)
if(CMAKE_Go_FLAGS_ENV_INIT)
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
endif()
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
endif()
endif()
set(Go_BIN_PATH
$ENV{GOPATH}
$ENV{GOROOT}
$ENV{GOROOT}/../bin
$ENV{GO_COMPILER}
/usr/bin
/usr/local/bin
)
if(CMAKE_Go_COMPILER_INIT)
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
else()
find_program(CMAKE_Go_COMPILER
NAMES go
PATHS ${Go_BIN_PATH}
)
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
message("-- The Golang compiler identification is ${VERSION}")
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
endif()
endif()
mark_as_advanced(CMAKE_Go_COMPILER)
configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
set(CMAKE_Go_COMPILER_LOADED 1)
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
set(CMAKE_Go_LINKER_PREFERENCE 40)
set(CMAKE_Go_OUTPUT_EXTENSION .o)
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
if(NOT CMAKE_Go_COMPILE_OBJECT)
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
endif()
if(NOT CMAKE_Go_LINK_EXECUTABLE)
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
endif()
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
# Setting Paddle Compile Flags
include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckCXXSymbolExists)
include(CheckTypeSize)
function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
  # Apple Clang is a different compiler than upstream Clang and has different version numbers.
# https://gist.github.com/yamaya/2924292
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif()
else()
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif()
endif()
endif()
endfunction()
CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# Common gpu architectures: Kepler, Maxwell
foreach(capability 30 35 50)
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
endif()
# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
endif()
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
function(GO_LIBRARY NAME BUILD_TYPE)
if(BUILD_TYPE STREQUAL "STATIC")
set(BUILD_MODE -buildmode=c-archive)
set(LIB_NAME "lib${NAME}.a")
else()
set(BUILD_MODE -buildmode=c-shared)
if(APPLE)
set(LIB_NAME "lib${NAME}.dylib")
else()
set(LIB_NAME "lib${NAME}.so")
endif()
endif()
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
# find Paddle directory.
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
# automatically get all dependencies specified in the source code
# for given target.
add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
add_custom_target(${NAME}_copyPaddle
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} ${NAME}_goGet)
endfunction(GO_LIBRARY)
@@ -30,7 +30,13 @@ func main() {
	log.SetLevel(level)
	timeout := time.Second * time.Duration((*etcdTimeout))
-	s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout)
	e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
	idx, err := e.Register()
	if err != nil {
		panic(err)
	}
	s, err := pserver.NewService(idx)
	if err != nil {
		panic(err)
	}
......
@@ -13,10 +13,13 @@ typedef int paddle_master_client;
import "C"
import (
	"strings"
	"sync"
	"time"
	"unsafe"
	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/coreos/etcd/clientv3"
	log "github.com/sirupsen/logrus"
)
@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client {
	return h
}
-type addresser string
-func (a addresser) Address() string {
-	return string(a)
-}
//export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
	p := C.GoString(etcdEndpoints)
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   strings.Split(p, ","),
		DialTimeout: time.Second * time.Duration(timeout),
	})
	if err != nil {
		panic(err)
	}
	ch := make(chan string, 1)
	a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
	if err != nil {
		panic(err)
	}
	ch <- a
	go master.WatchKey(cli, master.DefaultAddrPath, ch)
	c := master.NewClient(ch, bufSize)
	return add(c)
}
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
	a := C.GoString(addr)
-	c := master.NewClient(addresser(a), bufSize)
	ch := make(chan string, 1)
	ch <- a
	c := master.NewClient(ch, bufSize)
	return add(c)
}
......
@@ -2,18 +2,12 @@ package master
import (
	"os"
-	"time"
	"github.com/PaddlePaddle/Paddle/go/connection"
	"github.com/PaddlePaddle/recordio"
	log "github.com/sirupsen/logrus"
)
-// Addresser provide the address of the master server.
-type Addresser interface {
-	Address() string
-}
// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
@@ -24,11 +18,11 @@ type Client struct {
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
-func NewClient(addr Addresser, bufSize int) *Client {
func NewClient(addrCh <-chan string, bufSize int) *Client {
	c := &Client{}
	c.conn = connection.New()
	c.ch = make(chan []byte, bufSize)
-	go c.monitorMaster(addr)
	go c.monitorMaster(addrCh)
	go c.getRecords()
	return c
}
@@ -72,12 +66,10 @@ func (c *Client) getRecords() {
		}
	}
}
-func (c *Client) monitorMaster(addr Addresser) {
func (c *Client) monitorMaster(addrCh <-chan string) {
	lastMaster := ""
-	monitor := func() {
	for curMaster := range addrCh {
-		// get the lastest address of the master server,
		// connect to the new address once address changed.
-		curMaster := addr.Address()
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
@@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) {
				// to retry next time.
				curMaster = lastMaster
				}
			}
		}
		lastMaster = curMaster
	}
-	monitor()
-	ticker := time.NewTicker(10 * time.Second)
-	for _ = range ticker.C {
-		monitor()
-	}
}
// SetDataset set dataset for the master server to dispatch.
......
@@ -26,12 +26,6 @@ func init() {
	log.SetLevel(log.ErrorLevel)
}
-type TestAddresser string
-func (a TestAddresser) Address() string {
-	return string(a)
-}
func TestGetFinishTask(t *testing.T) {
	const path = "/tmp/master_client_test_0"
@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) {
	if err != nil {
		panic(err)
	}
	go func(l net.Listener) {
		s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
		if err != nil {
@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) {
	// Manually intialize client to avoid calling c.getRecords()
	c := &Client{}
	c.conn = connection.New()
-	go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
	addr := fmt.Sprintf(":%d", p)
	ch := make(chan string, 1)
	ch <- addr
	go c.monitorMaster(ch)
	c.SetDataset([]string{path})
	checkOnePass := func(i int) {
		var tasks []Task
		for idx := 0; idx < totalTask; idx++ {
......
@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) {
		path  = "/tmp/master_client_TestFull"
		total = 50
	)
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) {
	if err != nil {
		panic(err)
	}
	go func(l net.Listener) {
		s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
		if err != nil {
@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) {
	}
	w.Close()
	f.Close()
	curAddr := make(chan string, 1)
-	c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10)
	curAddr <- fmt.Sprintf(":%d", p)
	c := master.NewClient(curAddr, 10)
	c.SetDataset([]string{path})
	for pass := 0; pass < 50; pass++ {
		received := make(map[byte]bool)
		for i := 0; i < total; i++ {
......
@@ -18,8 +18,8 @@ const (
	DefaultAddrPath = "/master/addr"
)
-// EtcdClient is the etcd client that master uses for fault tolerance
-// and service registry.
// EtcdClient is the etcd client that the master uses for fault
// tolerance and service registry.
type EtcdClient struct {
	lockPath  string
	statePath string
@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) {
	state := kvs[0].Value
	return state, nil
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
valChan <- string(ev.Kv.Value)
}
}
}
cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
include(golang)
include(flags)
go_library(paddle_pserver_cclient STATIC)
add_subdirectory(test)
cmake_minimum_required(VERSION 3.0)
-add_executable(main main.c)
cc_library(main SRCS main.c DEPS paddle_pserver_cclient)
-add_dependencies(main paddle_pserver_cclient)
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
add_executable(test_cclient test_cclient.c)
add_dependencies(test_cclient paddle_pserver_cclient)
if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
set(CMAKE_EXE_LINKER_FLAGS "-pthread")
endif()
if(PROJ_ROOT)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/..)
target_link_libraries(main ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
target_link_libraries(test_cclient ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
else(PROJ_ROOT)
include_directories(${CMAKE_BINARY_DIR})
target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
endif(PROJ_ROOT)
package pserver
import (
	"errors"
	"hash/fnv"
	"sort"
	"time"
@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
	if len(grads) == 0 {
		return errors.New("no gradient received")
	}
	errCh := make(chan error, len(grads))
	for _, g := range grads {
		go func(g Gradient) {
......
@@ -7,7 +7,6 @@ import (
	"strconv"
	"strings"
	"testing"
-	"time"
	"github.com/PaddlePaddle/Paddle/go/pserver"
)
@@ -31,7 +30,7 @@ func init() {
		port[i] = p
		go func(l net.Listener) {
-			s, err := pserver.NewService("", time.Second*5)
			s, err := pserver.NewService(0)
			if err != nil {
				panic(err)
			}
......
package pserver
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
numPservers int
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
}
// NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
return &EtcdClient{
etcdTimeout: timeout,
numPservers: numPservers,
etcdEndpoints: endpoints,
}
}
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return 0, err
}
// initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: e.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout)
continue
}
e.etcdClient = cli
log.Debugf("inited client to %s", e.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
var pserverIdx int
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
idx = i
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil {
return 0, err
}
return idx, nil
}
package pserver
import (
-	"context"
	"errors"
	"fmt"
-	"strconv"
-	"strings"
	"sync"
-	"time"
-	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
-	"github.com/coreos/etcd/clientv3"
-	"github.com/coreos/etcd/clientv3/concurrency"
-	log "github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
@@ -55,160 +46,25 @@ type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
	initialized chan struct{}
	idx         int
	mu       sync.Mutex
	opt      *optimizer
	paramMap map[string]Parameter
-	etcdEndpoints string
-	etcdClient    *clientv3.Client
-	// etcdTimeout is also used as retry intervals.
-	etcdTimeout time.Duration
-	// desired number of pservers in the job.
-	// assume desired will not change during one training job.
-	desired int
-	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
-	externalIP string
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
-func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) {
func NewService(idx int) (*Service, error) {
-	s := &Service{opt: newOptimizer(sgd, 0.005)}
	s := &Service{
		idx: idx,
		opt: newOptimizer(sgd, 0.005),
	}
	s.paramMap = make(map[string]Parameter)
	s.initialized = make(chan struct{})
s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
-	} // if endpoints != ""
	// Bypass etcd registration if no endpoints specified
	return s, nil
}
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
......
@@ -10,7 +10,7 @@ import (
)
func TestFull(t *testing.T) {
-	s, err := pserver.NewService("", time.Second*5)
	s, err := pserver.NewService(0)
	if err != nil {
		t.Error(err)
	}
@@ -75,7 +75,7 @@ func TestFull(t *testing.T) {
}
func TestMultipleInit(t *testing.T) {
-	s, err := pserver.NewService("", time.Second*5)
	s, err := pserver.NewService(0)
	if err != nil {
		t.Error(err)
	}
@@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) {
}
func TestUninitialized(t *testing.T) {
-	s, err := pserver.NewService("", time.Second*5)
	s, err := pserver.NewService(0)
	err = s.SendGrad(pserver.Gradient{}, nil)
	if err.Error() != pserver.Uninitialized {
		t.FailNow()
@@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) {
}
func TestBlockUntilInitialized(t *testing.T) {
-	s, err := pserver.NewService("", time.Second*5)
	s, err := pserver.NewService(0)
	if err != nil {
		t.Error(err)
	}
......
@@ -9,7 +9,7 @@ add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(scripts)
add_subdirectory(optimizer)
-add_subdirectory(strings)
add_subdirectory(string)
if(Boost_FOUND)
  add_subdirectory(memory)
......
cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc)
cc_test(enforce_test SRCS enforce_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/string/printf.h>
#include <exception>
#include <sstream>
namespace paddle {
namespace framework {
/**
* @brief Enforce exception. Inherits std::exception
*
 * Whenever an enforced condition is not met, an EnforceNotMet exception will be thrown.
*/
class EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& msg, const char* file, int fileline) {
std::ostringstream sout;
sout << msg << " at [" << file << ":" << fileline << "];";
all_msg_ = sout.str();
}
const char* what() const noexcept override { return all_msg_.c_str(); }
private:
std::string all_msg_;
};
// From https://stackoverflow.com/questions/30130930/
// __builtin_expect is a compiler builtin (see the link above). Since the
// enforced condition should be true in most situations, the `UNLIKELY` macro
// helps the compiler generate faster code.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
/**
 * @brief Throw an EnforceNotMet exception, automatically filled with __FILE__ &
 * __LINE__
 *
 * This macro takes __VA_ARGS__; users can pass any type that can be
 * serialized to std::ostream.
*/
#define PADDLE_THROW(...) \
do { \
throw ::paddle::framework::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
/**
* @brief Enforce a condition, otherwise throw an EnforceNotMet
*/
#define PADDLE_ENFORCE(condition, ...) \
do { \
if (UNLIKELY(!(condition))) { \
PADDLE_THROW(__VA_ARGS__); \
} \
} while (0)
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@@ -9,18 +9,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#pragma once
-/**
- * __must_check macro. It make the function's return value must be used,
- * otherwise it will raise a compile warning. And also Paddle treat all compile
- * warnings as errors.
- */
-#ifdef __GNUC__
-#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400
-#define __must_check __attribute__((warn_unused_result))
-#else
-#define __must_check
-#endif
-#else
-#define __must_check
-#endif
#include <gtest/gtest.h>
#include <paddle/framework/enforce.h>
TEST(ENFORCE, OK) {
  PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
  size_t val = 1;
  const size_t limit = 10;
  PADDLE_ENFORCE(val < limit, "Enforce is OK too");
}
TEST(ENFORCE, FAILED) {
  bool in_catch = false;
  try {
    PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
  } catch (paddle::framework::EnforceNotMet err) {
    in_catch = true;
    std::string msg = "Enforce is not ok 123 at all";
    const char* what = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
      ASSERT_EQ(what[i], msg[i]);
    }
  }
  ASSERT_TRUE(in_catch);
}
\ No newline at end of file
@@ -25,21 +25,24 @@ class Variable {
 public:
  template <typename T>
  const T& Get() const {
-    PADDLE_ASSERT(holder_ != nullptr);
-    PADDLE_ASSERT(std::type_index(typeid(T)) ==
-                  std::type_index(holder_->Type()));
    PADDLE_ASSERT(IsType<T>());
    return *static_cast<const T*>(holder_->Ptr());
  }
  template <typename T>
  T* GetMutable() {
-    if (holder_ == nullptr ||
-        std::type_index(typeid(T)) != std::type_index(holder_->Type())) {
    if (!IsType<T>()) {
      holder_.reset(new PlaceholderImpl<T>(new T()));
    }
    return static_cast<T*>(holder_->Ptr());
  }
  template <typename T>
  bool IsType() const {
    return holder_ != nullptr &&
           std::type_index(typeid(T)) == std::type_index(holder_->Type());
  }
 private:
  struct Placeholder {
    virtual ~Placeholder() {}
......
@@ -601,7 +601,7 @@ void TrainerThread::backward() {
void TrainerThread::backwardCallback(Parameter* para) {
  // CPU parameters are merged in the end
-  if (!para->useGpu()) return;
  if (!para->useGpu() || para->isStatic()) return;
  int paramId = para->getID();
  if (multiMachine_->getNumThreads() == 1) {
......
@@ -191,6 +191,11 @@ void Layer::addOutputArgument(int deviceId) {
void Layer::copyOutputToOtherDevice() {
  for (size_t i = 0; i != outputOtherDevice_.size(); i++) {
    SetDevice device(outputOtherDevice_[i].deviceId);
// If outputOtherDevice_[i].value is a CpuMatrix,
// the copyFrom is a synchronous interface.
// If outputOtherDevice_[i].value is a GpuMatrix, since subsequent
// calculations are all on HPPL_STREAM_DEFAULT,
// copyFrom can be an asynchronous interface.
    outputOtherDevice_[i].value->copyFrom(*getOutputValue(),
                                          HPPL_STREAM_DEFAULT);
    outputOtherDevice_[i].sequenceStartPositions =
......
@@ -1565,6 +1565,8 @@ void CpuMatrix::copyFrom(const Matrix& src, hl_stream_t stream) {
                    const_cast<real*>(src.getData()),
                    sizeof(real) * elementCnt_,
                    stream);
// There is a need to add synchronization to ensure that the data is copied.
hl_stream_synchronize(stream);
  } else if (typeid(src) == typeid(CpuMatrix)) {
    memcpy(data_, src.getData(), sizeof(real) * elementCnt_);
  } else {
......
@@ -239,7 +239,8 @@ public:
    LOG(FATAL) << "Not implemented";
  }
-  // asynchronous copy
  // For GpuMatrix this is an asynchronous copy interface
  // For CpuMatrix this is a synchronous copy interface
  virtual void copyFrom(const Matrix& src, hl_stream_t stream) {
    LOG(FATAL) << "Not implemented";
  }
......
@@ -657,6 +657,8 @@ void CpuVectorT<T>::copyFrom(const VectorT<T>& src, hl_stream_t stream) {
                      (void*)src.getData(),
                      sizeof(T) * this->getSize(),
                      stream);
// There is a need to add synchronization to ensure that the data is copied.
hl_stream_synchronize(stream);
  } else {
    src.copyTo(this);
  }
......
@@ -168,11 +168,11 @@ public:
  virtual void copyFrom(const VectorT<T>& src) = 0;
  /**
-   * If use_gpu, this function will push the copy-task to the specifed-stream
-   * and return immediately.
   * If GpuVector, this function is an asynchronous interface that
   * will push the copy-task to the specified stream and return immediately.
   *
-   * If not use GPU, this function is same as
-   * the copyFrom(const VectorT<T>& src), which use stream HPPL_STREAM_DEFAULT.
   * If CpuVector, this function is a synchronous interface,
   * the same as copyFrom(const VectorT<T>& src).
   */
  virtual void copyFrom(const VectorT<T>& src, hl_stream_t stream) = 0;
......
@@ -1127,4 +1127,18 @@ TEST(Matrix, MaxOutFwdBwd) {
  }
}
TEST(CpuMatrix, copyFrom) {
const size_t height = 1000;
const size_t width = 1000;
CpuMatrix cpu(height, width);
GpuMatrix gpu(height, width);
CpuMatrix copy(height, width);
cpu.randomizeUniform();
gpu.copyFrom(cpu);
copy.copyFrom(gpu, HPPL_STREAM_DEFAULT);
TensorCheckEqual(cpu, copy);
}
#endif
@@ -6,4 +6,3 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
-cc_test(must_check_test SRCS must_check_test.cc)
#include <gtest/gtest.h>
#include <paddle/platform/must_check.h>
int __must_check SomeFunctionMustCheck() { return 0; }
TEST(MustCheck, all) {
// This line should not be compiled, because the
// return value of SomeFunctionMustCheck marked as __must_check
// SomeFunctionMustCheck();
}
\ No newline at end of file
@@ -109,6 +109,10 @@ class DenseScanner(IScanner):
        if len(self.__shape__) > 3:
            raise ValueError(
                "The dimension of input cannot be greater than 3.")
        if len(self.__shape__) == 0:
            raise ValueError(
                "The input should be a vector, please check your input data."
            )
        self.__dim__ = reduce(lambda x, y: x * y, self.__shape__)
        if len(self.__shape__) == 1 and self.__dim__ != self.input_type.dim:
            raise ValueError(
@@ -140,7 +144,7 @@ class DenseScanner(IScanner):
        if len(self.__shape__) > 1:
            # The last-two dimenstions are the frame height and width.
            # For example, the layout is CHW for 3-D feature of image.
-            # The H and W are the fram height and width.
            # The H and W are the frame height and width.
            h, w = self.__shape__[-2:]
            argument.setSlotFrameHeight(self.pos, h)
            argument.setSlotFrameWidth(self.pos, w)
......
@@ -31,6 +31,7 @@ Configuring cmake in /paddle/build ...
      -DWITH_DOC=OFF
      -DWITH_GPU=${WITH_GPU:-OFF}
      -DWITH_AVX=${WITH_AVX:-OFF}
      -DWITH_GOLANG=${WITH_GOLANG:-OFF}
      -DWITH_SWIG_PY=ON
      -DCUDNN_ROOT=/usr/
      -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
@@ -43,6 +44,7 @@ cmake .. \
      -DWITH_DOC=OFF \
      -DWITH_GPU=${WITH_GPU:-OFF} \
      -DWITH_AVX=${WITH_AVX:-OFF} \
      -DWITH_GOLANG=${WITH_GOLANG:-OFF} \
      -DWITH_SWIG_PY=ON \
      -DCUDNN_ROOT=/usr/ \
      -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} \
......
cc_library(stringpiece SRCS piece.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
@@ -14,7 +14,7 @@
limitations under the License.
*/
-#include "paddle/strings/stringpiece.h"
#include "paddle/string/piece.h"
#include <string.h>
@@ -23,29 +23,25 @@
#include <stdexcept>
namespace paddle {
namespace string {
-StringPiece::StringPiece() : data_(NULL), size_(0) {}
Piece::Piece() : data_(NULL), size_(0) {}
-StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) {
Piece::Piece(const char* d, size_t n) : data_(d), size_(n) {
  if (d == NULL && n != 0)
-    throw std::invalid_argument(
-        "StringPiece requires len to be 0 for NULL data");
    throw std::invalid_argument("Piece requires len to be 0 for NULL data");
}
-StringPiece::StringPiece(const char* s) : data_(s) {
-  size_ = (s == NULL) ? 0 : strlen(s);
-}
Piece::Piece(const char* s) : data_(s) { size_ = (s == NULL) ? 0 : strlen(s); }
-StringPiece::StringPiece(const std::string& s)
-    : data_(s.data()), size_(s.size()) {}
Piece::Piece(const std::string& s) : data_(s.data()), size_(s.size()) {}
-char StringPiece::operator[](size_t n) const {
char Piece::operator[](size_t n) const {
-  if (n >= len())
-    throw std::invalid_argument("index out of StringPiece length");
  if (n >= len()) throw std::invalid_argument("index out of Piece length");
  return data_[n];
}
-int Compare(StringPiece a, StringPiece b) {
int Compare(Piece a, Piece b) {
  const size_t min_len = (a.len() < b.len()) ? a.len() : b.len();
  int r = memcmp(a.data(), b.data(), min_len);
  if (r == 0) {
@@ -57,85 +53,86 @@ int Compare(StringPiece a, StringPiece b) {
  return r;
}
-bool operator==(StringPiece x, StringPiece y) {
bool operator==(Piece x, Piece y) {
  return ((x.len() == y.len()) &&
          (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0));
}
-bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
bool operator!=(Piece x, Piece y) { return !(x == y); }
-bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; }
bool operator<(Piece x, Piece y) { return Compare(x, y) < 0; }
-bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; }
bool operator>(Piece x, Piece y) { return Compare(x, y) > 0; }
-bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; }
bool operator<=(Piece x, Piece y) { return Compare(x, y) <= 0; }
-bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; }
bool operator>=(Piece x, Piece y) { return Compare(x, y) >= 0; }
-bool HasPrefix(StringPiece s, StringPiece x) {
bool HasPrefix(Piece s, Piece x) {
  return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0));
}
-bool HasSuffix(StringPiece s, StringPiece x) {
bool HasSuffix(Piece s, Piece x) {
  return ((s.len() >= x.len()) &&
          (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0));
}
-StringPiece SkipPrefix(StringPiece s, size_t n) {
Piece SkipPrefix(Piece s, size_t n) {
  if (n > s.len())
-    throw std::invalid_argument("Skip distance larger than StringPiece length");
    throw std::invalid_argument("Skip distance larger than Piece length");
-  return StringPiece(s.data() + n, s.len() - n);
  return Piece(s.data() + n, s.len() - n);
}
-StringPiece SkipSuffix(StringPiece s, size_t n) {
Piece SkipSuffix(Piece s, size_t n) {
  if (n > s.len())
-    throw std::invalid_argument("Skip distance larger than StringPiece length");
    throw std::invalid_argument("Skip distance larger than Piece length");
-  return StringPiece(s.data(), s.len() - n);
  return Piece(s.data(), s.len() - n);
}
-StringPiece TrimPrefix(StringPiece s, StringPiece x) {
Piece TrimPrefix(Piece s, Piece x) {
  return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s;
}
-StringPiece TrimSuffix(StringPiece s, StringPiece x) {
Piece TrimSuffix(Piece s, Piece x) {
  return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s;
}
-bool Contains(StringPiece s, StringPiece sub) {
bool Contains(Piece s, Piece sub) {
  return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end();
}
-size_t Index(StringPiece s, StringPiece sub) {
size_t Index(Piece s, Piece sub) {
  auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end());
-  return e != s.end() ? e - s.data() : StringPiece::npos;
  return e != s.end() ? e - s.data() : Piece::npos;
} }
size_t Find(StringPiece s, char c, size_t pos) { size_t Find(Piece s, char c, size_t pos) {
if (pos >= s.len()) { if (pos >= s.len()) {
return StringPiece::npos; return Piece::npos;
} }
const char* result = const char* result =
reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos)); reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos));
return result != nullptr ? result - s.data() : StringPiece::npos; return result != nullptr ? result - s.data() : Piece::npos;
} }
size_t RFind(StringPiece s, char c, size_t pos) { size_t RFind(Piece s, char c, size_t pos) {
if (s.len() == 0) return StringPiece::npos; if (s.len() == 0) return Piece::npos;
for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data();
p--) { p--) {
if (*p == c) { if (*p == c) {
return p - s.data(); return p - s.data();
} }
} }
return StringPiece::npos; return Piece::npos;
} }
StringPiece SubStr(StringPiece s, size_t pos, size_t n) { Piece SubStr(Piece s, size_t pos, size_t n) {
if (pos > s.len()) pos = s.len(); if (pos > s.len()) pos = s.len();
if (n > s.len() - pos) n = s.len() - pos; if (n > s.len() - pos) n = s.len() - pos;
return StringPiece(s.data() + pos, n); return Piece(s.data() + pos, n);
} }
std::ostream& operator<<(std::ostream& o, StringPiece piece) { std::ostream& operator<<(std::ostream& o, Piece piece) {
return o << piece.ToString(); return o << piece.ToString();
} }
} // namespace string
} // namespace paddle } // namespace paddle
...@@ -20,33 +20,34 @@ ...@@ -20,33 +20,34 @@
#include <string> #include <string>
namespace paddle { namespace paddle {
namespace string {
// StringPiece points into a std::string object but doesn't own the // Piece points into a std::string object but doesn't own the
// string. It is for efficient access to strings. Like Go's string // string. It is for efficient access to strings. Like Go's string
// type. Note that StringPiece doesn't mutate the underlying string, // type. Note that Piece doesn't mutate the underlying string,
// so it is thread-safe given that the underlying string doesn't // so it is thread-safe given that the underlying string doesn't
// change. Because StringPiece contains only a few data members, and // change. Because Piece contains only a few data members, and
// its syntax is simple as it doesn't own/manage the string, it is // its syntax is simple as it doesn't own/manage the string, it is
// cheap to construct StringPieces and pass them around. // cheap to construct Pieces and pass them around.
class StringPiece { class Piece {
public: public:
static const size_t npos = static_cast<size_t>(-1); static const size_t npos = static_cast<size_t>(-1);
// We provide non-explicit singleton constructors so users can // We provide non-explicit singleton constructors so users can
// pass in a "const char*" or a "string" wherever a "StringPiece" // pass in a "const char*" or a "string" wherever a "Piece"
// is expected. These constructors ensure that if data_ is NULL, // is expected. These constructors ensure that if data_ is NULL,
// size_ is 0. // size_ is 0.
StringPiece(); Piece();
StringPiece(const char* d, size_t n); Piece(const char* d, size_t n);
StringPiece(const char* d); Piece(const char* d);
StringPiece(const std::string& s); Piece(const std::string& s);
const char* data() const { return data_; } const char* data() const { return data_; }
size_t len() const { return size_; } size_t len() const { return size_; }
char operator[](size_t n) const; char operator[](size_t n) const;
// StringPiece doesn't own the string, so both iterator and const // Piece doesn't own the string, so both iterator and const
// iterator are const char* indeed. // iterator are const char* indeed.
typedef const char* const_iterator; typedef const char* const_iterator;
typedef const char* iterator; typedef const char* iterator;
...@@ -63,43 +64,44 @@ private: ...@@ -63,43 +64,44 @@ private:
// Intentionally copyable // Intentionally copyable
}; };
int Compare(StringPiece a, StringPiece b); int Compare(Piece a, Piece b);
bool operator==(StringPiece x, StringPiece y); bool operator==(Piece x, Piece y);
bool operator!=(StringPiece x, StringPiece y); bool operator!=(Piece x, Piece y);
bool operator<(StringPiece x, StringPiece y); bool operator<(Piece x, Piece y);
bool operator>(StringPiece x, StringPiece y); bool operator>(Piece x, Piece y);
bool operator<=(StringPiece x, StringPiece y); bool operator<=(Piece x, Piece y);
bool operator>=(StringPiece x, StringPiece y); bool operator>=(Piece x, Piece y);
bool HasPrefix(StringPiece s, StringPiece prefix); bool HasPrefix(Piece s, Piece prefix);
bool HasSuffix(StringPiece s, StringPiece suffix); bool HasSuffix(Piece s, Piece suffix);
StringPiece SkipPrefix(StringPiece s, size_t n); Piece SkipPrefix(Piece s, size_t n);
StringPiece SkipSuffix(StringPiece s, size_t n); Piece SkipSuffix(Piece s, size_t n);
// Skip the prefix (or suffix) if it matches with the string. // Skip the prefix (or suffix) if it matches with the string.
StringPiece TrimPrefix(StringPiece s, StringPiece prefix); Piece TrimPrefix(Piece s, Piece prefix);
StringPiece TrimSuffix(StringPiece s, StringPiece suffix); Piece TrimSuffix(Piece s, Piece suffix);
// Returns whether s contains sub. Any s except for empty s contains an // Returns whether s contains sub. Any s except for empty s contains an
// empty sub. // empty sub.
bool Contains(StringPiece s, StringPiece sub); bool Contains(Piece s, Piece sub);
// Return the first occurrence of sub in s, or npos. If both s and // Return the first occurrence of sub in s, or npos. If both s and
// sub is empty, it returns npos; otherwise, if only sub is empty, it // sub is empty, it returns npos; otherwise, if only sub is empty, it
// returns 0. // returns 0.
size_t Index(StringPiece s, StringPiece sub); size_t Index(Piece s, Piece sub);
// Return the first occurrence of c in s[pos:end], or npos. // Return the first occurrence of c in s[pos:end], or npos.
size_t Find(StringPiece s, char c, size_t pos); size_t Find(Piece s, char c, size_t pos);
// Search range is [0..pos] inclusive. If pos == npos, search everything. // Search range is [0..pos] inclusive. If pos == npos, search everything.
size_t RFind(StringPiece s, char c, size_t pos); size_t RFind(Piece s, char c, size_t pos);
StringPiece SubStr(StringPiece s, size_t pos, size_t n); Piece SubStr(Piece s, size_t pos, size_t n);
// allow StringPiece to be logged // allow Piece to be logged
std::ostream& operator<<(std::ostream& o, StringPiece piece); std::ostream& operator<<(std::ostream& o, Piece piece);
} // namespace string
} // namespace paddle } // namespace paddle
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
limitations under the License. limitations under the License.
*/ */
#include "paddle/strings/stringpiece.h" #include "paddle/string/piece.h"
#include <sstream> #include <sstream>
...@@ -22,42 +22,44 @@ ...@@ -22,42 +22,44 @@
TEST(StringPiece, Construct) { TEST(StringPiece, Construct) {
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(NULL, s.data()); EXPECT_EQ(NULL, s.data());
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
{ EXPECT_THROW(paddle::StringPiece s(NULL, 10000U), std::invalid_argument); }
{ {
paddle::StringPiece s(NULL); EXPECT_THROW(paddle::string::Piece s(NULL, 10000U), std::invalid_argument);
}
{
paddle::string::Piece s(NULL);
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
{ {
std::string a; std::string a;
EXPECT_EQ(0U, a.size()); EXPECT_EQ(0U, a.size());
paddle::StringPiece s(a); paddle::string::Piece s(a);
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
} }
TEST(StringPiece, CopyAndAssign) { TEST(StringPiece, CopyAndAssign) {
paddle::StringPiece empty; paddle::string::Piece empty;
EXPECT_EQ(0U, empty.len()); EXPECT_EQ(0U, empty.len());
paddle::StringPiece a("hello"); paddle::string::Piece a("hello");
paddle::StringPiece b = a; paddle::string::Piece b = a;
EXPECT_EQ(b.len(), strlen("hello")); EXPECT_EQ(b.len(), strlen("hello"));
EXPECT_EQ(a, b); EXPECT_EQ(a, b);
std::string storage("hello"); std::string storage("hello");
paddle::StringPiece c(storage); paddle::string::Piece c(storage);
EXPECT_EQ(a, c); EXPECT_EQ(a, c);
EXPECT_NE(a.data(), c.data()); EXPECT_NE(a.data(), c.data());
} }
TEST(StringPiece, Compare) { TEST(StringPiece, Compare) {
{ {
paddle::StringPiece a("hello"); paddle::string::Piece a("hello");
paddle::StringPiece b("world"); paddle::string::Piece b("world");
EXPECT_TRUE(a != b); EXPECT_TRUE(a != b);
EXPECT_FALSE(a == b); EXPECT_FALSE(a == b);
EXPECT_TRUE(a < b); EXPECT_TRUE(a < b);
...@@ -68,7 +70,7 @@ TEST(StringPiece, Compare) { ...@@ -68,7 +70,7 @@ TEST(StringPiece, Compare) {
EXPECT_GT(Compare(b, a), 0); EXPECT_GT(Compare(b, a), 0);
} }
{ {
paddle::StringPiece a, b; paddle::string::Piece a, b;
EXPECT_TRUE(a == b); EXPECT_TRUE(a == b);
EXPECT_FALSE(a != b); EXPECT_FALSE(a != b);
EXPECT_FALSE(a < b); EXPECT_FALSE(a < b);
...@@ -82,31 +84,31 @@ TEST(StringPiece, Compare) { ...@@ -82,31 +84,31 @@ TEST(StringPiece, Compare) {
TEST(StringPiece, ToString) { TEST(StringPiece, ToString) {
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(std::string(""), s.ToString()); EXPECT_EQ(std::string(""), s.ToString());
} }
{ {
paddle::StringPiece s(NULL); paddle::string::Piece s(NULL);
EXPECT_EQ(std::string(""), s.ToString()); EXPECT_EQ(std::string(""), s.ToString());
} }
{ {
paddle::StringPiece s("hello"); paddle::string::Piece s("hello");
EXPECT_EQ(std::string("hello"), s.ToString()); EXPECT_EQ(std::string("hello"), s.ToString());
} }
} }
TEST(StringPiece, HasPrefixSuffix) { TEST(StringPiece, HasPrefixSuffix) {
using paddle::HasPrefix; using paddle::string::HasPrefix;
using paddle::HasSuffix; using paddle::string::HasSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_FALSE(HasPrefix(s, "something")); EXPECT_FALSE(HasPrefix(s, "something"));
EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_FALSE(HasSuffix(s, "something")); EXPECT_FALSE(HasSuffix(s, "something"));
EXPECT_TRUE(HasSuffix(s, "")); EXPECT_TRUE(HasSuffix(s, ""));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_TRUE(HasPrefix(s, "a")); EXPECT_TRUE(HasPrefix(s, "a"));
EXPECT_TRUE(HasPrefix(s, "ap")); EXPECT_TRUE(HasPrefix(s, "ap"));
...@@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) { ...@@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) {
} }
TEST(StringPiece, SkipPrefixSuffix) { TEST(StringPiece, SkipPrefixSuffix) {
using paddle::SkipPrefix; using paddle::string::SkipPrefix;
using paddle::SkipSuffix; using paddle::string::SkipSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", SkipPrefix(s, 0)); EXPECT_EQ("", SkipPrefix(s, 0));
EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument); EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument);
...@@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) { ...@@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) {
EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument); EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument);
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("app", SkipPrefix(s, 0)); EXPECT_EQ("app", SkipPrefix(s, 0));
EXPECT_EQ("pp", SkipPrefix(s, 1)); EXPECT_EQ("pp", SkipPrefix(s, 1));
EXPECT_EQ("p", SkipPrefix(s, 2)); EXPECT_EQ("p", SkipPrefix(s, 2));
...@@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) { ...@@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) {
} }
TEST(StringPiece, TrimPrefixSuffix) { TEST(StringPiece, TrimPrefixSuffix) {
using paddle::TrimPrefix; using paddle::string::TrimPrefix;
using paddle::TrimSuffix; using paddle::string::TrimSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", TrimPrefix(s, "")); EXPECT_EQ("", TrimPrefix(s, ""));
EXPECT_EQ("", TrimPrefix(s, "something")); EXPECT_EQ("", TrimPrefix(s, "something"));
...@@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) { ...@@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) {
EXPECT_EQ("", TrimSuffix(s, "something")); EXPECT_EQ("", TrimSuffix(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("app", TrimPrefix(s, "")); EXPECT_EQ("app", TrimPrefix(s, ""));
EXPECT_EQ("pp", TrimPrefix(s, "a")); EXPECT_EQ("pp", TrimPrefix(s, "a"));
EXPECT_EQ("p", TrimPrefix(s, "ap")); EXPECT_EQ("p", TrimPrefix(s, "ap"));
...@@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) { ...@@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) {
} }
TEST(StringPiece, Contains) { TEST(StringPiece, Contains) {
using paddle::Contains; using paddle::string::Contains;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_FALSE(Contains(s, "")); EXPECT_FALSE(Contains(s, ""));
EXPECT_FALSE(Contains(s, "something")); EXPECT_FALSE(Contains(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_TRUE(Contains(s, "")); EXPECT_TRUE(Contains(s, ""));
EXPECT_TRUE(Contains(s, "a")); EXPECT_TRUE(Contains(s, "a"));
EXPECT_TRUE(Contains(s, "p")); EXPECT_TRUE(Contains(s, "p"));
...@@ -193,15 +195,15 @@ TEST(StringPiece, Contains) { ...@@ -193,15 +195,15 @@ TEST(StringPiece, Contains) {
} }
TEST(StringPiece, Index) { TEST(StringPiece, Index) {
using paddle::Index; using paddle::string::Index;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, Index(s, "")); EXPECT_EQ(npos, Index(s, ""));
EXPECT_EQ(npos, Index(s, "something")); EXPECT_EQ(npos, Index(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(0U, Index(s, "")); EXPECT_EQ(0U, Index(s, ""));
EXPECT_EQ(0U, Index(s, "a")); EXPECT_EQ(0U, Index(s, "a"));
EXPECT_EQ(1U, Index(s, "p")); EXPECT_EQ(1U, Index(s, "p"));
...@@ -213,14 +215,14 @@ TEST(StringPiece, Index) { ...@@ -213,14 +215,14 @@ TEST(StringPiece, Index) {
} }
TEST(StringPiece, Find) { TEST(StringPiece, Find) {
using paddle::Find; using paddle::string::Find;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, Find(s, 'a', 0U)); EXPECT_EQ(npos, Find(s, 'a', 0U));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(0U, Find(s, 'a', 0U)); EXPECT_EQ(0U, Find(s, 'a', 0U));
EXPECT_EQ(1U, Find(s, 'p', 0U)); EXPECT_EQ(1U, Find(s, 'p', 0U));
EXPECT_EQ(1U, Find(s, 'p', 1U)); EXPECT_EQ(1U, Find(s, 'p', 1U));
...@@ -230,14 +232,14 @@ TEST(StringPiece, Find) { ...@@ -230,14 +232,14 @@ TEST(StringPiece, Find) {
} }
TEST(StringPiece, RFind) { TEST(StringPiece, RFind) {
using paddle::RFind; using paddle::string::RFind;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, RFind(s, 'a', 0U)); EXPECT_EQ(npos, RFind(s, 'a', 0U));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(2U, RFind(s, 'p', 2U)); EXPECT_EQ(2U, RFind(s, 'p', 2U));
EXPECT_EQ(0U, RFind(s, 'a', 2U)); EXPECT_EQ(0U, RFind(s, 'a', 2U));
EXPECT_EQ(1U, RFind(s, 'p', 1U)); EXPECT_EQ(1U, RFind(s, 'p', 1U));
...@@ -247,15 +249,15 @@ TEST(StringPiece, RFind) { ...@@ -247,15 +249,15 @@ TEST(StringPiece, RFind) {
} }
TEST(StringPiece, SubStr) { TEST(StringPiece, SubStr) {
using paddle::SubStr; using paddle::string::SubStr;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 0, 1)); EXPECT_EQ("", SubStr(s, 0, 1));
EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 1, 0));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 1, 0));
EXPECT_EQ("", SubStr(s, 2, 0)); EXPECT_EQ("", SubStr(s, 2, 0));
...@@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) { ...@@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) {
} }
TEST(StringPiece, StreamOutput) { TEST(StringPiece, StreamOutput) {
using paddle::StringPiece; using paddle::string::Piece;
std::stringstream o; std::stringstream o;
o << StringPiece(); o << paddle::string::Piece();
EXPECT_EQ("", o.str()); EXPECT_EQ("", o.str());
o << StringPiece("hello"); o << paddle::string::Piece("hello");
EXPECT_EQ("hello", o.str()); EXPECT_EQ("hello", o.str());
o << StringPiece(); o << paddle::string::Piece();
EXPECT_EQ("hello", o.str()); EXPECT_EQ("hello", o.str());
} }
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Compared with std::stringstream, there are two primary purposes of
// string::Printf:
//
// 1. Type-safe printing, with why and how explained in
// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999.
// Implementations include
//
// https://github.com/c42f/tinyformat
// boost::format
// std::stringstream
//
// std::stringstream is not convenient enough in many cases. For example:
//
// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n";
//
// boost::format is the most convenient one. We can have
//
// std::cout << format("%2% %1%") % 36 % 77;
//
// or
//
// format fmter("%2% %1%");
// fmter % 36; fmter % 77;
// std::cout << fmter.c_str();
//
// But the overloading of % might be overkill and it would be
// more efficient if it can write to std::cout directly.
//
// tinyformat has an interface compatible with the C-printf style,
// and it can write to a stream or return a std::string:
//
// std::cout << tfm::printf(
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// or
//
// tfm::format(std::cout,
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// 2. High-performance -- most printed strings are not too long and
// don't need dynamic memory allocation. Many StringPrintf
// implementations don't enforce type safety, but are
// high-performance, including
//
// https://developers.google.com/optimization/reference/base/stringprintf/
// https://github.com/adobe/chromium/blob/master/base/stringprintf.h
// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h
//
// According to
// https://github.com/c42f/tinyformat#compile-time-and-code-bloat,
// boost::format runs too slow and results in large executable binary
// files. So here we port tinyformat.
#pragma once
#include <iostream>
#include <sstream>
#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat
namespace paddle {
namespace string {
template <typename... Args>
void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
}
template <typename... Args>
std::string Sprintf(const char* fmt, const Args&... args) {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
}
template <typename... Args>
void Printf(const char* fmt, const Args&... args) {
Fprintf(std::cout, fmt, args...);
}
} // namespace string
} // namespace paddle
#include "paddle/string/printf.h"
#include <string>
#include "gtest/gtest.h"
TEST(StringPrintf, StringPrintf) {
std::string weekday = "Wednesday";
const char* month = "July";
size_t day = 27;
long hour = 14;
int min = 44;
EXPECT_EQ(std::string("Wednesday, July 27, 14:44"),
paddle::string::Sprintf(
"%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min));
}
This diff has been collapsed.
cc_library(stringpiece SRCS stringpiece.cc)
cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags)
...@@ -19,7 +19,21 @@ limitations under the License. */ ...@@ -19,7 +19,21 @@ limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/platform/must_check.h"
/**
* __must_check macro. It requires the function's return value to be used;
* otherwise the compiler raises a warning. Paddle also treats all compile
* warnings as errors.
*/
#ifdef __GNUC__
#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400
#define __must_check __attribute__((warn_unused_result))
#else
#define __must_check
#endif
#else
#define __must_check
#endif
namespace paddle { namespace paddle {
......
...@@ -1381,7 +1381,7 @@ def inputs(layers, *args): ...@@ -1381,7 +1381,7 @@ def inputs(layers, *args):
if len(args) != 0: if len(args) != 0:
layers.extend(args) layers.extend(args)
Inputs(*[l.name for l in layers]) Inputs(* [l.name for l in layers])
def outputs(layers, *args): def outputs(layers, *args):
...@@ -1424,7 +1424,7 @@ def outputs(layers, *args): ...@@ -1424,7 +1424,7 @@ def outputs(layers, *args):
assert len(layers) > 0 assert len(layers) > 0
if HasInputsSet(): # input already set if HasInputsSet(): # input already set
Outputs(*[l.name for l in layers]) Outputs(* [l.name for l in layers])
return # just return outputs. return # just return outputs.
if len(layers) != 1: if len(layers) != 1:
......
...@@ -25,8 +25,9 @@ import uci_housing ...@@ -25,8 +25,9 @@ import uci_housing
import sentiment import sentiment
import wmt14 import wmt14
import mq2007 import mq2007
import flowers
__all__ = [ __all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007' 'uci_housing', 'wmt14', 'mq2007', 'flowers'
] ]
...@@ -13,18 +13,18 @@ ...@@ -13,18 +13,18 @@
# limitations under the License. # limitations under the License.
""" """
This module will download the dataset from This module will download the dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse the train/test set into paddle reader creators. and parse the train/test set into paddle reader creators.
This set contains images of flowers belonging to 102 different categories. This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category. minimum of 40 images for each category.
The database was used in: The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes. Proceedings of the Indian Conference on Computer Vision, number of classes. Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008) Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
""" """
...@@ -34,9 +34,9 @@ from common import download ...@@ -34,9 +34,9 @@ from common import download
import tarfile import tarfile
import scipy.io as scio import scipy.io as scio
from paddle.v2.image import * from paddle.v2.image import *
from paddle.v2.reader import *
import os import os
import numpy as np import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid'] __all__ = ['train', 'test', 'valid']
...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' ...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In the official 'readme', tstid is the flag of the test data
# and trnid is the flag of the train data. But the test set is larger than the train set,
# so we exchange the train data and test data.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample): def default_mapper(sample):
...@@ -53,8 +59,8 @@ def default_mapper(sample): ...@@ -53,8 +59,8 @@ def default_mapper(sample):
map image bytes data to type needed by model input layer map image bytes data to type needed by model input layer
''' '''
img, label = sample img, label = sample
img = paddle.image.load_image_bytes(img) img = load_image_bytes(img)
img = paddle.image.simple_transform(img, 256, 224, True) img = simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label return img.flatten().astype('float32'), label
...@@ -63,22 +69,23 @@ def reader_creator(data_file, ...@@ -63,22 +69,23 @@ def reader_creator(data_file,
setid_file, setid_file,
dataset_name, dataset_name,
mapper=default_mapper, mapper=default_mapper,
buffered_size=1024): buffered_size=1024,
use_xmap=True):
''' '''
1. read images from tar file and 1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/ merge images into batch files in 102flowers.tgz_batch/
2. get a reader to read sample from batch file 2. get a reader to read sample from batch file
:param data_file: downloaded data file :param data_file: downloaded data file
:type data_file: string :type data_file: string
:param label_file: downloaded label file :param label_file: downloaded label file
:type label_file: string :type label_file: string
:param setid_file: downloaded setid file containing information :param setid_file: downloaded setid file containing information
about how to split dataset about how to split dataset
:type setid_file: string :type setid_file: string
:param dataset_name: data set name (tstid|trnid|valid) :param dataset_name: data set name (tstid|trnid|valid)
:type dataset_name: string :type dataset_name: string
:param mapper: a function to map image bytes data to type :param mapper: a function to map image bytes data to type
needed by model input layer needed by model input layer
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images :param buffered_size: the size of buffer used to process images
...@@ -105,15 +112,17 @@ def reader_creator(data_file, ...@@ -105,15 +112,17 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']): for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) yield sample, int(label)
return paddle.reader.xmap_readers(mapper, reader, if use_xmap:
cpu_count(), buffered_size) return xmap_readers(mapper, reader, cpu_count(), buffered_size)
else:
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024): def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102] image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps: translated from original color image by steps:
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
...@@ -128,15 +137,15 @@ def train(mapper=default_mapper, buffered_size=1024): ...@@ -128,15 +137,15 @@ def train(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024): def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102] image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps: translated from original color image by steps:
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
...@@ -151,15 +160,15 @@ def test(mapper=default_mapper, buffered_size=1024): ...@@ -151,15 +160,15 @@ def test(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024): def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers validation set reader. Create flowers validation set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102] image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps: translated from original color image by steps:
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
...@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024): ...@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def fetch(): def fetch():
......
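As a hedged usage sketch (not part of this change set), the new use_xmap switch can be exercised as below; it assumes a Python 2 environment with paddle.v2 installed and network access to download the dataset files:

import paddle.v2.dataset.flowers as flowers

# use_xmap=False maps samples in the calling thread via map_readers,
# which is easier to debug than the multi-threaded xmap path.
reader = flowers.train(use_xmap=False)

for i, (pixels, label) in enumerate(reader()):
    # pixels is a flattened float32 array with values in [0, 1];
    # label is an int in [1, 102].
    print pixels.shape, label
    if i >= 2:
        break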
...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase): ...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
def test_train(self): def test_train(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train()) paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020) self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_test(self): def test_test(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test()) paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149) self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_valid(self): def test_valid(self):
......
...@@ -51,7 +51,7 @@ class Parameters(object): ...@@ -51,7 +51,7 @@ class Parameters(object):
def __init__(self): def __init__(self):
self.__param_conf__ = dict() self.__param_conf__ = dict()
self.__gradient_machines__ = [] self.__gradient_machines__ = []
self.__tmp_params__ = [] self.__tmp_params__ = dict()
def __append_config__(self, param_conf): def __append_config__(self, param_conf):
""" """
...@@ -128,13 +128,10 @@ class Parameters(object): ...@@ -128,13 +128,10 @@ class Parameters(object):
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
# create new parameter in python numpy. # create new parameter in python numpy.
if len(self.__tmp_params__) != 0: if key in self.__tmp_params__:
ret_list = [ return self.__tmp_params__[key]
mat for name, mat in self.__tmp_params__ if name == key else:
] return np.ndarray(shape=shape, dtype=np.float32)
if len(ret_list) == 1:
return ret_list[0]
return np.ndarray(shape=shape, dtype=np.float32)
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
param = __get_parameter_in_gradient_machine__( param = __get_parameter_in_gradient_machine__(
...@@ -187,7 +184,7 @@ class Parameters(object): ...@@ -187,7 +184,7 @@ class Parameters(object):
(shape, value.shape)) (shape, value.shape))
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
self.__tmp_params__.append((key, value)) self.__tmp_params__[key] = value
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
__copy_parameter_to_gradient_machine__(each_gradient_machine, __copy_parameter_to_gradient_machine__(each_gradient_machine,
...@@ -231,7 +228,7 @@ class Parameters(object): ...@@ -231,7 +228,7 @@ class Parameters(object):
raise ValueError("gradient_machine should be api.GradientMachine") raise ValueError("gradient_machine should be api.GradientMachine")
if len(self.__tmp_params__) != 0: if len(self.__tmp_params__) != 0:
for name, val in self.__tmp_params__: for name, val in self.__tmp_params__.iteritems():
try: try:
__copy_parameter_to_gradient_machine__(gradient_machine, __copy_parameter_to_gradient_machine__(gradient_machine,
name, val) name, val)
...@@ -287,6 +284,18 @@ class Parameters(object): ...@@ -287,6 +284,18 @@ class Parameters(object):
@staticmethod @staticmethod
def from_tar(f): def from_tar(f):
"""
Create a `Parameters` object from the given file. The
resulting `Parameters` object only contains the parameters
stored in that file, and it assumes those parameters match
the ones defined in the network. For example, it can be
used for inference.
:param f: the initialized model file.
:type f: tar file
:return: A Parameters object.
:rtype: Parameters.
"""
params = Parameters() params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r') tar = tarfile.TarFile(fileobj=f, mode='r')
for finfo in tar: for finfo in tar:
...@@ -302,6 +311,21 @@ class Parameters(object): ...@@ -302,6 +311,21 @@ class Parameters(object):
params.deserialize(param_name, f) params.deserialize(param_name, f)
return params return params
def init_from_tar(self, f):
"""
Unlike `from_tar`, this interface can be used to initialize
a subset of the network parameters from another saved model.
:param f: the initialized model file.
:type f: tar file
:return: Nothing.
"""
tar_param = Parameters.from_tar(f)
for pname in tar_param.names():
if pname in self.names():
self.set(pname, tar_param.get(pname))
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
""" """
......
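A minimal sketch of the two loading paths added above; the tiny topology and the file name 'pretrained.tar' are placeholders, and the file is assumed to have been written earlier with parameters.to_tar(f):

import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# A placeholder network, only to obtain a Parameters object.
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(10))
y = paddle.layer.fc(input=x, size=2, act=paddle.activation.Softmax())
parameters = paddle.parameters.create(y)

# from_tar builds a fresh Parameters object that matches the saved model.
with open('pretrained.tar') as f:
    inference_params = paddle.parameters.Parameters.from_tar(f)

# init_from_tar only overwrites the parameters whose names also occur
# in the current network, so partial initialization is possible.
with open('pretrained.tar') as f:
    parameters.init_from_tar(f)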
...@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could be used in user ...@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could be used in user
program. program.
""" """
__all__ = ['np_array', 'text_file'] __all__ = ['np_array', 'text_file', "recordio"]
def np_array(x): def np_array(x):
...@@ -55,3 +55,24 @@ def text_file(path): ...@@ -55,3 +55,24 @@ def text_file(path):
f.close() f.close()
return reader return reader
def recordio(path):
"""
Creates a data reader that outputs records one by one from the given recordio file.
:param path: path of the recordio file
:returns: data reader of the recordio file
"""
import recordio as rec
def reader():
f = rec.reader(path)
while True:
r = f.read()
if r is None:
break
yield r
f.close()
return reader
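A short, hypothetical usage sketch of the new creator; 'data.recordio' is a placeholder path that must point at an existing recordio file:

import paddle.v2.reader.creator as creator

# Each record comes back exactly as the byte string that was written.
reader = creator.recordio('data.recordio')
for record in reader():
    print len(record)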
...@@ -166,12 +166,12 @@ def buffered(reader, size): ...@@ -166,12 +166,12 @@ def buffered(reader, size):
The buffered data reader will read and save data entries into a The buffered data reader will read and save data entries into a
buffer. Reading from the buffered data reader will proceed as long buffer. Reading from the buffered data reader will proceed as long
as the buffer is not empty. as the buffer is not empty.
:param reader: the data reader to read from. :param reader: the data reader to read from.
:type reader: callable :type reader: callable
:param size: max buffer size. :param size: max buffer size.
:type size: int :type size: int
:returns: the buffered data reader. :returns: the buffered data reader.
""" """
...@@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:type mapper: callable :type mapper: callable
:param reader: the data reader to read from :param reader: the data reader to read from
:type reader: callable :type reader: callable
:param process_num: process number to handle original sample :param process_num: process number to handle original sample
:type process_num: int :type process_num: int
:param buffer_size: max buffer size :param buffer_size: max buffer size
:type buffer_size: int :type buffer_size: int
...@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:rtype: callable :rtype: callable
""" """
end = XmapEndSignal() end = XmapEndSignal()
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
...@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_order += 1 in_order += 1
in_queue.put(end) in_queue.put(end)
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# define a worker to handle samples from in_queue by mapper # define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue # and put mapped samples into out_queue
def handle_worker(in_queue, out_queue, mapper): def handle_worker(in_queue, out_queue, mapper):
...@@ -298,19 +289,27 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -298,19 +289,27 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end) in_queue.put(end)
out_queue.put(end) out_queue.put(end)
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = []
for i in xrange(process_num):
worker = Thread(target=target, args=args)
worker.daemon = True
workers.append(worker)
for w in workers:
w.start()
def xreader(): def xreader():
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = []
for i in xrange(process_num):
worker = Thread(target=target, args=args)
worker.daemon = True
workers.append(worker)
for w in workers:
w.start()
sample = out_queue.get() sample = out_queue.get()
while not isinstance(sample, XmapEndSignal): while not isinstance(sample, XmapEndSignal):
yield sample yield sample
......
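Because the queues and worker threads are now created inside xreader(), the decorated reader can be iterated more than once. A rough sketch with made-up mapper and source readers, assuming paddle.v2 is installed:

import paddle.v2 as paddle

def square(x):            # hypothetical mapper
    return x * x

def numbers():            # hypothetical source reader
    for i in xrange(10):
        yield i

# 2 worker threads, buffer size 16, order=False by default.
reader = paddle.v2.reader.xmap_readers(square, numbers, 2, 16)

first = sorted(reader())   # order is not preserved, so sort before comparing
second = sorted(reader())  # a second pass now works as well
assert first == second == [i * i for i in xrange(10)]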
...@@ -34,5 +34,14 @@ class TestTextFile(unittest.TestCase): ...@@ -34,5 +34,14 @@ class TestTextFile(unittest.TestCase):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
class TestRecordIO(unittest.TestCase):
def test_recordio(self):
path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio(path)
for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -132,15 +132,17 @@ class TestXmap(unittest.TestCase): ...@@ -132,15 +132,17 @@ class TestXmap(unittest.TestCase):
for order in orders: for order in orders:
for tNum in thread_nums: for tNum in thread_nums:
for size in buffered_size: for size in buffered_size:
result = [] reader = paddle.v2.reader.xmap_readers(mapper,
for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(0), reader_creator_10(0),
tNum, size, order)(): tNum, size, order)
result.append(i) for n in xrange(3):
if not order: result = []
result.sort() for i in reader():
for idx, e in enumerate(result): result.append(i)
self.assertEqual(e, mapper(idx)) if not order:
result.sort()
for idx, e in enumerate(result):
self.assertEqual(e, mapper(idx))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -20,14 +20,17 @@ import cStringIO ...@@ -20,14 +20,17 @@ import cStringIO
import numpy import numpy
def __rand_param_config__(name): def __rand_param_config__(name, psize=None):
conf = ParameterConfig() conf = ParameterConfig()
conf.name = name conf.name = name
size = 1 size = 1
for i in xrange(2): if psize is None:
dim = random.randint(1, 1000) for i in xrange(2):
conf.dims.append(dim) dim = random.randint(1, 1000)
size *= dim conf.dims.append(dim)
size *= dim
else:
size = psize
conf.size = size conf.size = size
assert conf.IsInitialized() assert conf.IsInitialized()
return conf return conf
...@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase): ...@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase):
expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32) expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32)
assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6)) assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6))
def test_init_from_tar(self):
def get_param(names, size):
p = parameters.Parameters()
for k, v in zip(names, size):
p.__append_config__(__rand_param_config__(k, v))
for name in p.names():
param = p.get(name)
param[:] = numpy.random.uniform(
-1.0, 1.0, size=p.get_shape(name))
p.set(name, param)
return p
def get_parames():
name1 = ['param_0', 'param_1']
size1 = [128, 256]
p1 = get_param(name1, size1)
file1 = cStringIO.StringIO()
p1.to_tar(file1)
file1.seek(0)
name2 = ['param_0', 'param_1', 'param_2']
size2 = [128, 256, 288]
p2 = get_param(name2, size2)
file2 = cStringIO.StringIO()
p2.to_tar(file2)
file2.seek(0)
return p1, file1, p2, file2
p1, file1, p2, file2 = get_parames()
p2.init_from_tar(file1)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
p1, file1, p2, file2 = get_parames()
p1.init_from_tar(file2)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -15,7 +15,8 @@ setup_requires=["requests", ...@@ -15,7 +15,8 @@ setup_requires=["requests",
"protobuf==3.1", "protobuf==3.1",
"recordio", "recordio",
"matplotlib", "matplotlib",
"rarfile"] "rarfile",
"scipy>=0.19.0"]
if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires+=["opencv-python"] setup_requires+=["opencv-python"]
......