提交 86d184ed 编写于 作者: L Luo Tao

Merge branch 'develop' into save_log_prob

...@@ -19,3 +19,9 @@ third_party/ ...@@ -19,3 +19,9 @@ third_party/
# clion workspace. # clion workspace.
cmake-build-* cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake
...@@ -17,14 +17,20 @@ ...@@ -17,14 +17,20 @@
- id: detect-private-key - id: detect-private-key
files: (?!.*third_party)^.*$ | (?!.*book)^.*$ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - repo: local
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
hooks: hooks:
- id: clang-formater - id: clang-format
- repo: https://github.com/dnephin/pre-commit-golang name: clang-format
sha: e4693a4c282b4fc878eda172a929f7a6508e7d16 description: Format files with ClangFormat.
entry: clang-format -i
language: system
files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks: hooks:
- id: go-fmt - id: go-fmt
files: (.*\.go) types:
- id: go-lint - go
files: (.*\.go) - id: gometalinter
types:
- go
...@@ -4,6 +4,7 @@ cache: ...@@ -4,6 +4,7 @@ cache:
- $HOME/.ccache - $HOME/.ccache
- $HOME/.cache/pip - $HOME/.cache/pip
- $TRAVIS_BUILD_DIR/build/third_party - $TRAVIS_BUILD_DIR/build/third_party
- $TRAVIS_BUILD_DIR/build_android/third_party
sudo: required sudo: required
dist: trusty dist: trusty
os: os:
...@@ -11,6 +12,7 @@ os: ...@@ -11,6 +12,7 @@ os:
env: env:
- JOB=build_doc - JOB=build_doc
- JOB=check_style - JOB=check_style
- JOB=build_android
addons: addons:
apt: apt:
packages: packages:
...@@ -39,6 +41,8 @@ before_install: ...@@ -39,6 +41,8 @@ before_install:
- pip install rarfile - pip install rarfile
- curl https://glide.sh/get | bash - curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)" - eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- go get -u github.com/alecthomas/gometalinter
- gometalinter --install
- | - |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script: script:
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License # limitations under the License
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
...@@ -28,13 +27,17 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -28,13 +27,17 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(Boost QUIET) if(NOT ANDROID)
find_package(Boost QUIET)
endif()
include(simd) include(simd)
################################ Configurations ####################################### ################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF)
option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
...@@ -73,6 +76,10 @@ if(ANDROID) ...@@ -73,6 +76,10 @@ if(ANDROID)
"Disable PYTHON when cross-compiling for Android" FORCE) "Disable PYTHON when cross-compiling for Android" FORCE)
set(WITH_RDMA OFF CACHE STRING set(WITH_RDMA OFF CACHE STRING
"Disable RDMA when cross-compiling for Android" FORCE) "Disable RDMA when cross-compiling for Android" FORCE)
set(WITH_MKLDNN OFF CACHE STRING
"Disable MKLDNN when cross-compiling for Android" FORCE)
set(WITH_MKLML OFF CACHE STRING
"Disable MKLML package when cross-compiling for Android" FORCE)
endif(ANDROID) endif(ANDROID)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
...@@ -86,6 +93,7 @@ endif() ...@@ -86,6 +93,7 @@ endif()
######################################################################################## ########################################################################################
include(external/mklml) # download mklml package
include(external/zlib) # download, build, install zlib include(external/zlib) # download, build, install zlib
include(external/gflags) # download, build, install gflags include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog include(external/glog) # download, build, install glog
...@@ -93,10 +101,12 @@ include(external/gtest) # download, build, install gtest ...@@ -93,10 +101,12 @@ include(external/gtest) # download, build, install gtest
include(external/protobuf) # download, build, install protobuf include(external/protobuf) # download, build, install protobuf
include(external/python) # download, build, install python include(external/python) # download, build, install python
include(external/openblas) # download, build, install openblas include(external/openblas) # download, build, install openblas
include(external/mkldnn) # download, build, install mkldnn
include(external/swig) # download, build, install swig include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any include(external/any) # download libn::any
include(external/eigen) # download eigen3 include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(cudnn) # set cudnn libraries, must before configure include(cudnn) # set cudnn libraries, must before configure
include(configure) # add paddle env configuration include(configure) # add paddle env configuration
...@@ -133,12 +143,21 @@ if(WITH_GPU) ...@@ -133,12 +143,21 @@ if(WITH_GPU)
endif(NOT WITH_DSO) endif(NOT WITH_DSO)
endif(WITH_GPU) endif(WITH_GPU)
if(WITH_MKLDNN)
list(APPEND EXTERNAL_LIBS ${MKLDNN_LIBRARY} ${MKLDNN_IOMP_LIB})
endif()
if(USE_NNPACK) if(USE_NNPACK)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB} "rt") include(external/nnpack)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIBS})
endif(USE_NNPACK) endif(USE_NNPACK)
add_subdirectory(proto) add_subdirectory(proto)
# "add_subdirectory(go)" should be placed after the following loine,
# because it depends on paddle/optimizer.
add_subdirectory(paddle/optimizer)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it. # placed after this block, because they depends on it.
if(WITH_GOLANG) if(WITH_GOLANG)
...@@ -146,7 +165,9 @@ if(WITH_GOLANG) ...@@ -146,7 +165,9 @@ if(WITH_GOLANG)
endif(WITH_GOLANG) endif(WITH_GOLANG)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(python) if(WITH_PYTHON)
add_subdirectory(python)
endif()
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
...@@ -25,9 +25,9 @@ COPY ./paddle/scripts/docker/root/ /root/ ...@@ -25,9 +25,9 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y \ apt-get install -y \
git python-pip python-dev openssh-server bison \ git python-pip python-dev openssh-server bison \
wget unzip tar xz-utils bzip2 gzip coreutils ntp \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-numpy python-matplotlib gcc g++ \ python-numpy python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \ automake locales clang-format-3.8 swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \ liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \
......
...@@ -14,6 +14,17 @@ RUN apt-get update && \ ...@@ -14,6 +14,17 @@ RUN apt-get update && \
wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \ wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
apt-get clean -y apt-get clean -y
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# git credential to skip password typing # git credential to skip password typing
RUN git config --global credential.helper store RUN git config --global credential.helper store
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/develop/doc/) [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://doc.paddlepaddle.org/develop/doc/)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/doc_cn/) [![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://doc.paddlepaddle.org/develop/doc_cn/)
[![Coverage Status](https://coveralls.io/repos/github/PaddlePaddle/Paddle/badge.svg?branch=develop)](https://coveralls.io/github/PaddlePaddle/Paddle?branch=develop) [![Coverage Status](https://coveralls.io/repos/github/PaddlePaddle/Paddle/badge.svg?branch=develop)](https://coveralls.io/github/PaddlePaddle/Paddle?branch=develop)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
...@@ -61,35 +61,36 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl ...@@ -61,35 +61,36 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl
## Installation ## Installation
It is recommended to check out the It is recommended to check out the
[Docker installation guide](http://www.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html) [Docker installation guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html)
before looking into the before looking into the
[build from source guide](http://www.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html) [build from source guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html)
## Documentation ## Documentation
We provide [English](http://www.paddlepaddle.org/develop/doc/) and We provide [English](http://doc.paddlepaddle.org/develop/doc/) and
[Chinese](http://www.paddlepaddle.org/doc_cn/) documentation. [Chinese](http://doc.paddlepaddle.org/doc_cn/) documentation.
- [Deep Learning 101](http://book.paddlepaddle.org/index.html) - [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from the this online interactive book that can run in Jupyter Notebook. You might want to start from the this online interactive book that can run in Jupyter Notebook.
- [Distributed Training](http://www.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html) - [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html)
You can run distributed training jobs on MPI clusters. You can run distributed training jobs on MPI clusters.
- [Distributed Training on Kubernetes](http://www.paddlepaddle.org/develop/doc/howto/usage/k8s/k8s_en.html) - [Distributed Training on Kubernetes](http://doc.paddlepaddle.org/develop/doc/howto/usage/k8s/k8s_en.html)
You can also run distributed training jobs on Kubernetes clusters. You can also run distributed training jobs on Kubernetes clusters.
- [Python API](http://www.paddlepaddle.org/develop/doc/api/index_en.html) - [Python API](http://doc.paddlepaddle.org/develop/doc/api/index_en.html)
Our new API enables much shorter programs. Our new API enables much shorter programs.
- [How to Contribute](http://www.paddlepaddle.org/develop/doc/howto/dev/contribute_to_paddle_en.html) - [How to Contribute](http://doc.paddlepaddle.org/develop/doc/howto/dev/contribute_to_paddle_en.html)
We appreciate your contributions! We appreciate your contributions!
## Ask Questions ## Ask Questions
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues). You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
......
...@@ -15,23 +15,44 @@ ...@@ -15,23 +15,44 @@
set(CBLAS_FOUND OFF) set(CBLAS_FOUND OFF)
## Find MKL First. ## Find MKLML First.
set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs") if(WITH_MKLML AND MKLML_INC_DIR AND MKLML_LIB)
set(MKL_ROOT ${INTEL_ROOT}/mkl CACHE PATH "Folder contains MKL") set(CBLAS_FOUND ON)
set(CBLAS_PROVIDER MKLML)
set(CBLAS_INC_DIR ${MKLML_INC_DIR})
set(CBLAS_LIBRARIES ${MKLML_LIB})
add_definitions(-DPADDLE_USE_MKLML)
add_definitions(-DLAPACK_FOUND)
message(STATUS "Found cblas and lapack in MKLML "
"(include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
return()
endif()
## Then find MKL.
set(INTEL_MKL_ROOT "/opt/intel/mkl" CACHE PATH "Folder contains intel mkl libs")
set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder contains env MKL")
set(MKL_INCLUDE_SEARCH_PATHS
${MKL_ROOT}/include
${INTEL_MKL_ROOT}/include)
set(MKL_LIB_SEARCH_PATHS
${MKL_ROOT}/lib
${MKL_ROOT}/lib/intel64
${INTEL_MKL_ROOT}/lib
${INTEL_MKL_ROOT}/lib/intel64)
find_path(MKL_INC_DIR mkl.h PATHS find_path(MKL_INC_DIR mkl.h PATHS
${MKL_ROOT}/include) ${MKL_INCLUDE_SEARCH_PATHS})
find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS
${MKL_ROOT}/include) ${MKL_INCLUDE_SEARCH_PATHS})
find_library(MKL_CORE_LIB NAMES mkl_core PATHS find_library(MKL_CORE_LIB NAMES mkl_core PATHS
${MKL_ROOT}/lib ${MKL_LIB_SEARCH_PATHS})
${MKL_ROOT}/lib/intel64)
find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS
${MKL_ROOT}/lib ${MKL_LIB_SEARCH_PATHS})
${MKL_ROOT}/lib/intel64)
find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS
${MKL_ROOT}/lib ${MKL_LIB_SEARCH_PATHS})
${MKL_ROOT}/lib/intel64)
if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64)
set(CBLAS_FOUND ON) set(CBLAS_FOUND ON)
......
...@@ -67,6 +67,30 @@ else() ...@@ -67,6 +67,30 @@ else()
include_directories(${CUDA_TOOLKIT_INCLUDE}) include_directories(${CUDA_TOOLKIT_INCLUDE})
endif(NOT WITH_GPU) endif(NOT WITH_GPU)
if(WITH_MKLDNN)
add_definitions(-DPADDLE_USE_MKLDNN)
if (WITH_MKLML AND MKLDNN_IOMP_DIR)
message(STATUS "Enable Intel OpenMP at ${MKLDNN_IOMP_DIR}")
set(OPENMP_FLAGS "-fopenmp")
set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
else()
find_package(OpenMP)
if(OPENMP_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
else()
message(WARNING "Can not find OpenMP."
"Some performance features in MKLDNN may not be available")
endif()
endif()
endif(WITH_MKLDNN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
...@@ -102,12 +126,19 @@ if(WITH_GOLANG) ...@@ -102,12 +126,19 @@ if(WITH_GOLANG)
message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide") message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide")
endif() endif()
add_custom_target(go_vendor) # this command will only run when the file it depends is missing
add_custom_command(TARGET go_vendor # or has changed, or the output is missing.
add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/glide
COMMAND env GOPATH=${GOPATH} ${GLIDE} install COMMAND env GOPATH=${GOPATH} ${GLIDE} install
COMMAND touch ${CMAKE_BINARY_DIR}/glide
DEPENDS ${PROJ_ROOT}/go/glide.lock
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go" WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go"
) )
add_dependencies(go_vendor go_path)
# depends on the custom command which outputs
# ${CMAKE_BINARY_DIR}/glide, the custom command does not need to
# run every time this target is built.
add_custom_target(go_vendor DEPENDS ${CMAKE_BINARY_DIR}/glide go_path)
endif() endif()
endif(WITH_GOLANG) endif(WITH_GOLANG)
...@@ -27,7 +27,8 @@ set(IGNORE_PATTERN ...@@ -27,7 +27,8 @@ set(IGNORE_PATTERN
.*cblas\\.h.* .*cblas\\.h.*
.*\\.pb\\.txt .*\\.pb\\.txt
.*LtrDataProvider.* .*LtrDataProvider.*
.*MultiDataProvider.*) .*MultiDataProvider.*
.*pb.*)
# add_style_check_target # add_style_check_target
# #
...@@ -52,14 +53,13 @@ macro(add_style_check_target TARGET_NAME) ...@@ -52,14 +53,13 @@ macro(add_style_check_target TARGET_NAME)
endif() endif()
endforeach() endforeach()
if(LINT MATCHES ON) if(LINT MATCHES ON)
# cpplint code style
get_filename_component(base_filename ${filename} NAME) get_filename_component(base_filename ${filename} NAME)
set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint) set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint)
add_custom_command(OUTPUT ${CUR_GEN} add_custom_command(TARGET ${TARGET_NAME} PRE_BUILD
PRE_BUILD COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
COMMAND env ${py_env} "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py" "--filter=${STYLE_FILTER}"
"--filter=${STYLE_FILTER}" "--write-success=${CUR_GEN}" ${filename}
"--write-success=${CUR_GEN}" ${filename}
DEPENDS ${filename}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif() endif()
endforeach() endforeach()
......
...@@ -106,6 +106,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -106,6 +106,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
SET(CMAKE_SYSTEM_PROCESSOR armv7-a) SET(CMAKE_SYSTEM_PROCESSOR armv7-a)
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
SET(CMAKE_SYSTEM_PROCESSOR aarch64)
ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-") SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF() ENDIF()
...@@ -162,6 +166,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -162,6 +166,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}") STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}") STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}")
...@@ -186,6 +194,10 @@ ELSE() ...@@ -186,6 +194,10 @@ ELSE()
SET(CMAKE_ANDROID_STANDALONE_TOOLCHAIN ${ANDROID_STANDALONE_TOOLCHAIN}) SET(CMAKE_ANDROID_STANDALONE_TOOLCHAIN ${ANDROID_STANDALONE_TOOLCHAIN})
ENDIF() ENDIF()
SET(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ABI}) SET(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ABI})
SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE}) IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON}) SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE})
IF(ANDROID_ABI STREQUAL "armeabi-v7a")
SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON})
ENDIF()
ENDIF()
ENDIF() ENDIF()
...@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3) ...@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" GIT_TAG "master"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -52,6 +52,7 @@ ExternalProject_Add( ...@@ -52,6 +52,7 @@ ExternalProject_Add(
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog) ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog) LIST(APPEND external_project_dependencies glog)
...@@ -34,9 +34,15 @@ IF(WITH_TESTING) ...@@ -34,9 +34,15 @@ IF(WITH_TESTING)
"${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE) "${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE)
ENDIF(WIN32) ENDIF(WIN32)
IF(WITH_MKLML)
# wait for mklml downloading completed
SET(GTEST_DEPENDS ${MKLML_PROJECT})
ENDIF()
ExternalProject_Add( ExternalProject_Add(
extern_gtest extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${GTEST_DEPENDS}
GIT_REPOSITORY "https://github.com/google/googletest.git" GIT_REPOSITORY "https://github.com/google/googletest.git"
GIT_TAG "release-1.8.0" GIT_TAG "release-1.8.0"
PREFIX ${GTEST_SOURCES_DIR} PREFIX ${GTEST_SOURCES_DIR}
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_MKLDNN})
return()
ENDIF(NOT ${WITH_MKLDNN})
INCLUDE(ExternalProject)
SET(MKLDNN_PROJECT "extern_mkldnn")
SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_ROOT ${CMAKE_INSTALL_PREFIX})
IF(NOT "$ENV{HOME}" STREQUAL "/root")
SET(MKLDNN_INSTALL_ROOT "$ENV{HOME}")
ENDIF()
SET(MKLDNN_INSTALL_DIR "${MKLDNN_INSTALL_ROOT}/opt/paddle/third_party/mkldnn")
SET(MKLDNN_INCLUDE_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
IF(WIN32)
MESSAGE(WARNING "It is not supported compiling with mkldnn in windows Paddle yet."
"Force WITH_MKLDNN=OFF")
SET(WITH_MKLDNN OFF)
return()
ELSE(WIN32)
SET(MKLDNN_LIBRARY "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
#SET(CMAKE_MACOSX_RPATH 1) # hold for MacOS
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
ENDIF(WIN32)
INCLUDE_DIRECTORIES(${MKLDNN_INCLUDE_DIR})
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
SET(MKLDNN_MKLROOT ${MKLML_ROOT})
SET(MKLDNN_IOMP_LIB ${MKLML_IOMP_LIB})
SET(MKLDNN_IOMP_DIR ${MKLML_LIB_DIR})
ENDIF()
ExternalProject_Add(
${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "v0.9"
PREFIX ${MKLDNN_SOURCES_DIR}
CONFIGURE_COMMAND mkdir -p <SOURCE_DIR>/build
BUILD_COMMAND cd <SOURCE_DIR>/build
&& cmake .. -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} -DMKLROOT=${MKLDNN_MKLROOT}
&& $(MAKE)
INSTALL_COMMAND cd <SOURCE_DIR>/build && $(MAKE) install
UPDATE_COMMAND ""
)
ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIBRARY})
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIBRARY}")
LIST(APPEND external_project_dependencies mkldnn)
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_MKLML})
return()
ENDIF(NOT ${WITH_MKLML})
INCLUDE(ExternalProject)
SET(MKLML_PROJECT "extern_mklml")
SET(MKLML_VER "mklml_lnx_2018.0.20170425")
SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz")
SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
SET(MKLML_DST_DIR "opt/paddle/third_party/mklml")
SET(MKLML_INSTALL_ROOT "${CMAKE_INSTALL_PREFIX}")
IF(NOT "$ENV{HOME}" STREQUAL "/root")
SET(MKLML_INSTALL_ROOT "$ENV{HOME}")
ENDIF()
SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER})
SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
SET(mklml_cmakefile ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt)
FILE(WRITE ${mklml_cmakefile} "PROJECT(MKLML)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${MKLML_VER}\n"
" DESTINATION ${MKLML_DST_DIR})\n")
ExternalProject_Add(
${MKLML_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${MKLML_SOURCE_DIR}
DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate -O ${MKLML_DOWNLOAD_DIR}/${MKLML_VER}.tgz ${MKLML_URL}
&& tar -xzf ${MKLML_DOWNLOAD_DIR}/${MKLML_VER}.tgz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLML_INSTALL_ROOT}
)
ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB})
ADD_DEPENDENCIES(mklml ${MKLML_PROJECT})
LIST(APPEND external_project_dependencies mklml)
...@@ -7,10 +7,24 @@ set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK") ...@@ -7,10 +7,24 @@ set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK")
find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include) find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include)
find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib) find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib)
find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib) find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib)
find_library(NNPACK_UKERNELS_LIB NAMES nnpack_ukernels PATHS ${NNPACK_ROOT}/lib)
find_library(NNPACK_CPUFEATURES_LIB NAMES cpufeatures PATHS ${NNPACK_ROOT}/lib)
if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB) if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB)
set(NNPACK_FOUND ON) set(NNPACK_FOUND ON)
INCLUDE_DIRECTORIES(${NNPACK_INC_DIR}) INCLUDE_DIRECTORIES(${NNPACK_INC_DIR})
set(NNPACK_LIBS)
list(APPEND NNPACK_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB})
if (NNPACK_UKERNELS_LIB)
list(APPEND NNPACK_LIBS ${NNPACK_UKERNELS_LIB})
endif()
if (NNPACK_CPUFEATURES_LIB)
list(APPEND NNPACK_LIBS ${NNPACK_CPUFEATURES_LIB})
endif()
if(NOT ANDROID)
list(APPEND NNPACK_LIBS "rt")
endif()
else() else()
message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})") message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})")
endif() endif()
...@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND}) ...@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND})
# arm_soft_fp_abi branch of OpenBLAS to support softfp # arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5") SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0) IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(TARGET "ARMV7")
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(TARGET "ARMV8")
ENDIF()
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI) ELSEIF(RPI)
# use hardfp # use hardfp
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")
......
INCLUDE(ExternalProject)
SET(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind)
INCLUDE_DIRECTORIES(${PYBIND_SOURCE_DIR}/src/extern_pybind/include)
ExternalProject_Add(
extern_pybind
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/pybind/pybind11.git"
GIT_TAG "v2.1.1"
PREFIX ${PYBIND_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/pybind_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(pybind STATIC ${dummyfile})
else()
add_library(pybind INTERFACE)
endif()
add_dependencies(pybind extern_pybind)
LIST(APPEND external_project_dependencies pybind)
...@@ -18,6 +18,9 @@ INCLUDE(python_module) ...@@ -18,6 +18,9 @@ INCLUDE(python_module)
FIND_PACKAGE(PythonInterp 2.7) FIND_PACKAGE(PythonInterp 2.7)
IF(WITH_PYTHON) IF(WITH_PYTHON)
FIND_PACKAGE(PythonLibs 2.7) FIND_PACKAGE(PythonLibs 2.7)
# Fixme: Maybe find a static library. Get SHARED/STATIC by FIND_PACKAGE.
ADD_LIBRARY(python SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET python PROPERTY IMPORTED_LOCATION ${PYTHON_LIBRARIES})
ENDIF(WITH_PYTHON) ENDIF(WITH_PYTHON)
SET(py_env "") SET(py_env "")
......
...@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag) ...@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif() endif()
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers. # Apple Clang is a different compiler than upstream Clang which havs different version numbers.
...@@ -109,7 +114,9 @@ set(COMMON_FLAGS ...@@ -109,7 +114,9 @@ set(COMMON_FLAGS
-Wno-unused-function -Wno-unused-function
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=sign-compare -Wno-error=sign-compare
-Wno-error=unused-local-typedefs) -Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in Pybind11
)
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
-fPIC -fPIC
...@@ -122,6 +129,7 @@ set(GPU_COMMON_FLAGS ...@@ -122,6 +129,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=unused-local-typedefs -Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
) )
if (APPLE) if (APPLE)
...@@ -150,7 +158,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) ...@@ -150,7 +158,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
......
...@@ -90,10 +90,11 @@ ...@@ -90,10 +90,11 @@
# including binary directory for generated headers. # including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE AND NOT ANDROID)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
endif(NOT APPLE) set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE AND NOT ANDROID)
function(merge_static_libs TARGET_NAME) function(merge_static_libs TARGET_NAME)
set(libs ${ARGN}) set(libs ${ARGN})
...@@ -103,6 +104,7 @@ function(merge_static_libs TARGET_NAME) ...@@ -103,6 +104,7 @@ function(merge_static_libs TARGET_NAME)
foreach(lib ${libs}) foreach(lib ${libs})
list(APPEND libs_deps ${${lib}_LIB_DEPENDS}) list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
endforeach() endforeach()
list(REMOVE_DUPLICATES libs_deps)
if(APPLE) # Use OSX's libtool to merge archives if(APPLE) # Use OSX's libtool to merge archives
# To produce a library we need at least one source file. # To produce a library we need at least one source file.
...@@ -126,7 +128,7 @@ function(merge_static_libs TARGET_NAME) ...@@ -126,7 +128,7 @@ function(merge_static_libs TARGET_NAME)
# Get the file names of the libraries to be merged # Get the file names of the libraries to be merged
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>) set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
endforeach() endforeach()
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles}) COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
else() # general UNIX: use "ar" to extract objects and re-add to a common lib else() # general UNIX: use "ar" to extract objects and re-add to a common lib
...@@ -144,11 +146,11 @@ function(merge_static_libs TARGET_NAME) ...@@ -144,11 +146,11 @@ function(merge_static_libs TARGET_NAME)
DEPENDS ${lib} ${objdir} DEPENDS ${lib} ${objdir}
WORKING_DIRECTORY ${objdir}) WORKING_DIRECTORY ${objdir})
# Empty dummy source file that goes into merged library # Empty dummy source file that goes into merged library
set(mergebase ${lib}.mergebase.c) set(mergebase ${lib}.mergebase.c)
add_custom_command(OUTPUT ${mergebase} add_custom_command(OUTPUT ${mergebase}
COMMAND ${CMAKE_COMMAND} -E touch ${mergebase} COMMAND ${CMAKE_COMMAND} -E touch ${mergebase}
DEPENDS ${objlistfile}) DEPENDS ${objlistfile})
list(APPEND mergebases "${mergebase}") list(APPEND mergebases "${mergebase}")
endforeach() endforeach()
...@@ -183,6 +185,10 @@ function(cc_library TARGET_NAME) ...@@ -183,6 +185,10 @@ function(cc_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
endif() endif()
# cpplint code style
add_style_check_target(${TARGET_NAME} ${cc_library_SRCS})
else(cc_library_SRCS) else(cc_library_SRCS)
if (cc_library_DEPS) if (cc_library_DEPS)
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
...@@ -284,8 +290,22 @@ function(go_library TARGET_NAME) ...@@ -284,8 +290,22 @@ function(go_library TARGET_NAME)
set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
endif() endif()
# Add dummy code to support `make target_name` under Terminal Command
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
# This custom command will always run since it depends on a not
# existing file.
add_custom_command(
OUTPUT dummy_rebulid_${TARGET_NAME}
COMMAND cmake -E touch ${dummyfile}
)
# Create a custom target that depends on the custom command output
# file, so the custom command can be referenced as a dependency by
# `add_dependencies`.
add_custom_target(rebuild_${TARGET_NAME}
DEPENDS dummy_rebulid_${TARGET_NAME}
)
# Add dummy code to support `make target_name` under Terminal Command
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
if (go_library_SHARED OR go_library_shared) if (go_library_SHARED OR go_library_shared)
add_library(${TARGET_NAME} SHARED ${dummyfile}) add_library(${TARGET_NAME} SHARED ${dummyfile})
...@@ -296,11 +316,17 @@ function(go_library TARGET_NAME) ...@@ -296,11 +316,17 @@ function(go_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${go_library_DEPS}) add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS) endif(go_library_DEPS)
# The "source file" of the library is `${dummyfile}` which never
# change, so the target will never rebuild. Make the target depends
# on the custom command that touches the library "source file", so
# rebuild will always happen.
add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME})
set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code # Golang build source code
...@@ -308,7 +334,7 @@ function(go_library TARGET_NAME) ...@@ -308,7 +334,7 @@ function(go_library TARGET_NAME)
-o "${${TARGET_NAME}_LIB_PATH}" -o "${${TARGET_NAME}_LIB_PATH}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}"
# must run under GOPATH # must run under GOPATH
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_dependencies(${TARGET_NAME} go_vendor) add_dependencies(${TARGET_NAME} go_vendor)
endfunction(go_library) endfunction(go_library)
...@@ -319,14 +345,11 @@ function(go_binary TARGET_NAME) ...@@ -319,14 +345,11 @@ function(go_binary TARGET_NAME)
cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
# TODO: don't know what ${TARGET_NAME}_link does
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary) endfunction(go_binary)
...@@ -334,15 +357,18 @@ endfunction(go_binary) ...@@ -334,15 +357,18 @@ endfunction(go_binary)
function(go_test TARGET_NAME) function(go_test TARGET_NAME)
set(options OPTIONAL) set(options OPTIONAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs DEPS)
cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test -race
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS} ".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
function(proto_library TARGET_NAME) function(proto_library TARGET_NAME)
......
...@@ -104,6 +104,11 @@ cross_channel_norm ...@@ -104,6 +104,11 @@ cross_channel_norm
------------------ ------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm .. autoclass:: paddle.v2.layer.cross_channel_norm
:noindex: :noindex:
row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers Recurrent Layers
================ ================
...@@ -198,6 +203,10 @@ identity_projection ...@@ -198,6 +203,10 @@ identity_projection
.. autoclass:: paddle.v2.layer.identity_projection .. autoclass:: paddle.v2.layer.identity_projection
:noindex: :noindex:
slice_projection
-------------------
.. autoclass:: paddle.v2.layer.slice_projection
:noindex:
table_projection table_projection
---------------- ----------------
...@@ -316,6 +325,11 @@ scaling ...@@ -316,6 +325,11 @@ scaling
.. autoclass:: paddle.v2.layer.scaling .. autoclass:: paddle.v2.layer.scaling
:noindex: :noindex:
clip
----
.. autoclass:: paddle.v2.layer.clip
:noindex:
slope_intercept slope_intercept
--------------- ---------------
.. autoclass:: paddle.v2.layer.slope_intercept .. autoclass:: paddle.v2.layer.slope_intercept
...@@ -474,6 +488,11 @@ prelu ...@@ -474,6 +488,11 @@ prelu
.. autoclass:: paddle.v2.layer.prelu .. autoclass:: paddle.v2.layer.prelu
:noindex: :noindex:
gated_unit
-----------
.. autoclass:: paddle.v2.layer.gated_unit
:noindex:
Detection output Layer Detection output Layer
====================== ======================
......
...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future. ...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election ### Trainer Election
One trainer will be elected as the one to save the model. When using One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to etcd, trainer ID is a randomly generated UUID, the trainer will
elect one trainer. When not using etcd, unique trainer IDs will be contact the master server requesting to save the model, and find out
given by the administrator, the trainer whose ID is "0" is elected to if itself is elected. When the master server is not used, unique
save the model. trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path ### Model Save Path
......
...@@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`. ...@@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`.
```cpp ```cpp
class Scope { class Scope {
public: public:
Variable* CreateVariable(const std::string& name); Variable* NewVar(const std::string& name);
const Variable* GetVariable(const std::string& name) const; const Variable* FindVar(const std::string& name) const;
private: private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
...@@ -58,12 +58,12 @@ class Scope { ...@@ -58,12 +58,12 @@ class Scope {
public: public:
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {} Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const { Variable* FindVar(const std::string& name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
return it->second.get(); return it->second.get();
} else if (parent_ != nullptr) { } else if (parent_ != nullptr) {
return parent_->GetVariable(name); return parent_->FindVar(name);
} else { } else {
return nullptr; return nullptr;
} }
...@@ -95,10 +95,10 @@ class Scope { ...@@ -95,10 +95,10 @@ class Scope {
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr); static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
// return nullptr if not found. // return nullptr if not found.
Variable* GetVariable(const std::string& name) const; Variable* FindVar(const std::string& name) const;
// return if already contains same name variable. // return if already contains same name variable.
Variable* CreateVariable(const std::string& name); Variable* NewVar(const std::string& name);
private: private:
std::shared_ptr<Scope> parent_; std::shared_ptr<Scope> parent_;
...@@ -107,11 +107,11 @@ class Scope { ...@@ -107,11 +107,11 @@ class Scope {
``` ```
## Only scope can create a variable ## Only scope can create a variable
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`. To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `NewVar` can construct `Variable`.
## When scope destroyed, all variables inside this scope should be destroyed together ## When scope destroyed, all variables inside this scope should be destroyed together
The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together. The scope hold unique pointers for all variables. User can `FindVar` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
## Sharing a parent scope ## Sharing a parent scope
...@@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar ...@@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar
## Orthogonal interface ## Orthogonal interface
`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily. `FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `NewVar` will return a `Error` when there is a name conflict locally. Combine `FindVar` and `NewVar`, we can implement `NewVar` easily.
...@@ -49,6 +49,7 @@ message AttrProto { ...@@ -49,6 +49,7 @@ message AttrProto {
message VarProto { message VarProto {
required string name = 1; required string name = 1;
required string comment = 2; required string comment = 2;
required bool is_tensor = 3;
}; };
message OpProto { message OpProto {
......
...@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异 ...@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。 * 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习律或者对数据进行归一化处理。 主要的解决办法是减小学习律或者对数据进行归一化处理。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
pip uninstall py_paddle paddle
然后安装paddle的python环境, 在build目录下执行
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x} \frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x}
假设 :math:`z = f(W^T x + b)` ,那么 假设 :math:`z = W^T x + b` ,那么
.. math:: .. math::
......
...@@ -29,7 +29,7 @@ Fully connected layer takes a dense input vector with dimension :math:`D_i`. It ...@@ -29,7 +29,7 @@ Fully connected layer takes a dense input vector with dimension :math:`D_i`. It
where :math:`f(.)` is an nonlinear *activation* function, such as sigmoid, tanh, and Relu. where :math:`f(.)` is an nonlinear *activation* function, such as sigmoid, tanh, and Relu.
The transformation matrix :math:`W` and bias vector :math:`b` are the *parameters* of the layer. The *parameters* of a layer are learned during training in the *backward pass*. The backward pass computes the gradients of the output function with respect to all parameters and inputs. The optimizer can use chain rule to compute the gradients of the loss function with respect to each parameter. The transformation matrix :math:`W` and bias vector :math:`b` are the *parameters* of the layer. The *parameters* of a layer are learned during training in the *backward pass*. The backward pass computes the gradients of the output function with respect to all parameters and inputs. The optimizer can use chain rule to compute the gradients of the loss function with respect to each parameter.
Suppose our loss function is :math:`c(y)`, then Suppose our loss function is :math:`c(y)`, then
...@@ -37,7 +37,7 @@ Suppose our loss function is :math:`c(y)`, then ...@@ -37,7 +37,7 @@ Suppose our loss function is :math:`c(y)`, then
\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x} \frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x}
Suppose :math:`z = f(W^T x + b)`, then Suppose :math:`z = W^T x + b`, then
.. math:: .. math::
...@@ -48,7 +48,7 @@ This derivative can be automatically computed by our base layer class. ...@@ -48,7 +48,7 @@ This derivative can be automatically computed by our base layer class.
Then, for fully connected layer, we need to compute: Then, for fully connected layer, we need to compute:
.. math:: .. math::
\frac{\partial z}{\partial x} = W, \frac{\partial z_j}{\partial W_{ij}} = x_i, \frac{\partial z}{\partial b} = \mathbf 1 \frac{\partial z}{\partial x} = W, \frac{\partial z_j}{\partial W_{ij}} = x_i, \frac{\partial z}{\partial b} = \mathbf 1
where :math:`\mathbf 1` is an all one vector, :math:`W_{ij}` is the number at the i-th row and j-th column of the matrix :math:`W`, :math:`z_j` is the j-th component of the vector :math:`z`, and :math:`x_i` is the i-th component of the vector :math:`x`. where :math:`\mathbf 1` is an all one vector, :math:`W_{ij}` is the number at the i-th row and j-th column of the matrix :math:`W`, :math:`z_j` is the j-th component of the vector :math:`z`, and :math:`x_i` is the i-th component of the vector :math:`x`.
...@@ -322,7 +322,7 @@ All the gradient check unit tests are located in :code:`paddle/gserver/tests/tes ...@@ -322,7 +322,7 @@ All the gradient check unit tests are located in :code:`paddle/gserver/tests/tes
/* weight */ true); /* weight */ true);
} }
} }
If you are creating a new file for the test, such as :code:`paddle/gserver/tests/testFCGrad.cpp`, you need to add the file to :code:`paddle/gserver/tests/CMakeLists.txt`. An example is given below. All the unit tests will run when you execute the command :code:`make tests`. Notice that some layers might need high accuracy for the gradient check unit tests to work well. You need to configure :code:`WITH_DOUBLE` to `ON` when configuring cmake. If you are creating a new file for the test, such as :code:`paddle/gserver/tests/testFCGrad.cpp`, you need to add the file to :code:`paddle/gserver/tests/CMakeLists.txt`. An example is given below. All the unit tests will run when you execute the command :code:`make tests`. Notice that some layers might need high accuracy for the gradient check unit tests to work well. You need to configure :code:`WITH_DOUBLE` to `ON` when configuring cmake.
.. code-block:: bash .. code-block:: bash
......
...@@ -41,7 +41,7 @@ PaddlePaddle文档需要准备的环境相对较复杂,所以我们推荐使 ...@@ -41,7 +41,7 @@ PaddlePaddle文档需要准备的环境相对较复杂,所以我们推荐使
python -c "import py_paddle" python -c "import py_paddle"
如果提示错误,那么用户需要在本地编译安装PaddlePaddle,请参考 `源码编译文档 <http://www.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html>`_ 。 如果提示错误,那么用户需要在本地编译安装PaddlePaddle,请参考 `源码编译文档 <http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html>`_ 。
注意,用户在首次编译安装PaddlePaddle时,请将WITH_DOC选项关闭。在编译安装正确之后,请再次确认py_paddle包已经安装,即可进行下一步操作。 注意,用户在首次编译安装PaddlePaddle时,请将WITH_DOC选项关闭。在编译安装正确之后,请再次确认py_paddle包已经安装,即可进行下一步操作。
如果提示正确,可以执行以下命令编译生成文档,即 如果提示正确,可以执行以下命令编译生成文档,即
...@@ -68,9 +68,9 @@ PaddlePaddle文档使用 `sphinx`_ 自动生成,用户可以参考sphinx教程 ...@@ -68,9 +68,9 @@ PaddlePaddle文档使用 `sphinx`_ 自动生成,用户可以参考sphinx教程
如何更新www.paddlepaddle.org文档 如何更新www.paddlepaddle.org文档
================================ ================================
开发者给PaddlePaddle代码增加的注释以PR的形式提交到github中,提交方式可参见 `贡献文档 <http://paddlepaddle.org/develop/doc_cn/howto/dev/contribute_to_paddle_cn.html>`_ 。 开发者给PaddlePaddle代码增加的注释以PR的形式提交到github中,提交方式可参见 `贡献文档 <http://doc.paddlepaddle.org/develop/doc_cn/howto/dev/contribute_to_paddle_cn.html>`_ 。
目前PaddlePaddle的develop分支的文档是自动触发更新的,用户可以分别查看最新的 `中文文档 <http://www.paddlepaddle.org/develop/doc_cn/>`_ 和 目前PaddlePaddle的develop分支的文档是自动触发更新的,用户可以分别查看最新的 `中文文档 <http://doc.paddlepaddle.org/develop/doc_cn/>`_ 和
`英文文档 <http://www.paddlepaddle.org/develop/doc/>`_ 。 `英文文档 <http://doc.paddlepaddle.org/develop/doc/>`_ 。
......
...@@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c) ...@@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver) add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master) add_subdirectory(cmd/master)
add_subdirectory(master/c) add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
go_binary(master SRC master.go DEPS paddle_go_optimizer) go_binary(master SRC master.go)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
import ( import (
...@@ -5,12 +19,15 @@ import ( ...@@ -5,12 +19,15 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
...@@ -20,11 +37,18 @@ func main() { ...@@ -20,11 +37,18 @@ func main() {
port := flag.Int("port", 8080, "port of the master server.") port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
level, e := log.ParseLevel(*logLevel)
candy.Must(e)
log.SetLevel(level)
if *endpoints == "" { if *endpoints == "" {
log.Warningln("-endpoints not set, fault tolerance not be enabled.") log.Warningln("-endpoints not set, fault tolerance not be enabled.")
} }
...@@ -46,6 +70,20 @@ func main() { ...@@ -46,6 +70,20 @@ func main() {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() {
log.Infoln("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Errorln(err)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
...@@ -62,8 +100,12 @@ func main() { ...@@ -62,8 +100,12 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
err = http.Serve(l, nil) go func() {
if err != nil { err = http.Serve(l, nil)
log.Fatal(err) if err != nil {
} log.Fatal(err)
}
}()
<-c
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
import ( import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
...@@ -18,49 +35,70 @@ func main() { ...@@ -18,49 +35,70 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) level, err := log.ParseLevel(*logLevel)
if err != nil { candy.Must(err)
panic(err)
}
log.SetLevel(level) log.SetLevel(level)
var idx int var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
timeout := time.Second * time.Duration((*etcdTimeout)) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) idx, err = e.Register(*port)
idx, err = e.Register() candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
if err != nil { if err != nil {
panic(err) if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.")
} else {
log.Errorf("Fetch checkpoint failed, %s", err)
}
} }
} }
s, err := pserver.NewService(idx) shutdown := func() {
if err != nil { log.Infoln("shutting down gracefully")
panic(err) sErr := e.Shutdown()
if sErr != nil {
log.Errorln(sErr)
}
} }
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err)
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { candy.Must(err)
panic(err)
}
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { candy.Must(err)
panic(err)
}
log.Infof("start pserver at port %d", *port) go func() {
err = http.Serve(l, nil) log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil)
candy.Must(err)
}()
if err != nil { <-c
panic(err)
}
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection package connection
import ( import (
......
hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7 hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
updated: 2017-06-27T14:05:48.925262819+08:00 updated: 2017-07-29T07:34:48.722757905+08:00
imports: imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7 version: c31bec0f29facff13f7c3e3d948e55dd6689ed42
subpackages: subpackages:
- alarm
- auth
- auth/authpb - auth/authpb
- client
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes - etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb - etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb - mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf - name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7 version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages: subpackages:
...@@ -17,12 +107,61 @@ imports: ...@@ -17,12 +107,61 @@ imports:
- proto - proto
- name: github.com/golang/snappy - name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9 version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag - name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04 version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 1351f936d976c60a0a48d728281922cf63eafb8d
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
...@@ -34,11 +173,15 @@ imports: ...@@ -34,11 +173,15 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
repo: https://github.com/golang/sys.git
vcs: git
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77 version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform
...@@ -58,4 +201,23 @@ imports: ...@@ -58,4 +201,23 @@ imports:
- stats - stats
- tap - tap
- transport - transport
testImports: [] - name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
...@@ -6,7 +6,19 @@ import: ...@@ -6,7 +6,19 @@ import:
subpackages: subpackages:
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag - package: github.com/namsral/flag
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy
- package: golang.org/x/crypto
vcs: git
repo: https://github.com/golang/crypto.git
- package: golang.org/x/sys
vcs: git
repo: https://github.com/golang/sys.git
- package: golang.org/x/text
vcs: git
repo: https://github.com/golang/text.git
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(master_test)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer) go_library(paddle_master SHARED DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
/* /*
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#define PADDLE_MASTER_OK 0 #define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1 #define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client; typedef int paddle_master_client;
*/ */
import "C" import "C"
...@@ -19,11 +35,9 @@ import ( ...@@ -19,11 +35,9 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client) var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client var curHandle C.paddle_master_client
...@@ -52,32 +66,32 @@ func remove(client C.paddle_master_client) *master.Client { ...@@ -52,32 +66,32 @@ func remove(client C.paddle_master_client) *master.Client {
} }
//export paddle_new_etcd_master_client //export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client { func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints) p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{ endpoints := strings.Split(p, ",")
Endpoints: strings.Split(p, ","), c, err := master.NewClient(
DialTimeout: time.Second * time.Duration(timeout), master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
}) master.WithBuffer(bufSize),
if err != nil { )
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil { if err != nil {
panic(err) panic(err)
} }
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c) return add(c)
} }
//export paddle_new_master_client //export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr) a := C.GoString(addr)
ch := make(chan string, 1) c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
ch <- a if err != nil {
c := master.NewClient(ch, bufSize) panic(err)
}
return add(c) return add(c)
} }
...@@ -86,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) { ...@@ -86,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client) remove(client)
} }
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset //export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client) c := get(client)
...@@ -104,23 +124,28 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -104,23 +124,28 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK return C.PADDLE_MASTER_OK
} }
// return value: // paddle_next_record gets the nexts training record.
// 0:ok //
// -1:error // returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r, err := c.NextRecord() r, err := c.NextRecord()
if err != nil { if err != nil {
// Error // NOTE: use errors to indicate pass ends
// TODO: return the type of error? if err.Error() == master.ErrAllTaskFailed.Error() ||
*record = (*C.uchar)(nullPtr) err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1 return -1
} }
if len(r) == 0 { if len(r) == 0 {
// Empty record // Empty record
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nil)
return 0 return 0
} }
...@@ -130,6 +155,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { ...@@ -130,6 +155,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return C.int(size) return C.int(size)
} }
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free //export mem_free
func mem_free(p unsafe.Pointer) { func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so // "free" may be a better name for this function, but doing so
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
"os" "os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan record ch chan record
bufSize int
} }
type record struct { type record struct {
...@@ -19,33 +36,104 @@ type record struct { ...@@ -19,33 +36,104 @@ type record struct {
err error err error
} }
// NewClient creates a new Client. // WithBuffer sets the client to buffer the training record.
// //
// bufSize is the record buffer size. NextRecord will read from this // bufSize is the record buffer size. NextRecord will read from this
// buffer. // buffer.
func NewClient(addrCh <-chan string, bufSize int) *Client { func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.bufSize = bufSize
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh) for _, opt := range opts {
go c.getRecords() err := opt(c)
return c if err != nil {
return nil, err
}
}
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil
} }
func (c *Client) getRecords() { // StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}
func (c *Client) getRecords(passID int) {
for { for {
t, err := c.getTask() t, err := c.getTask(passID)
if err != nil { if err != nil {
// TODO(helin): wait before move on with next if err.Error() == ErrPassBefore.Error() ||
// getTask call. err.Error() == ErrNoMoreAvailable.Error() ||
log.Errorln(err) err.Error() == ErrAllTaskFailed.Error() {
continue c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
} }
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if err != nil { if e != nil {
log.Errorln(err) log.Errorln(e)
continue continue
} }
...@@ -68,7 +156,10 @@ func (c *Client) getRecords() { ...@@ -68,7 +156,10 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data // We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly // instance of the task is read. This is not exactly
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
c.taskFinished(t.ID) err = c.taskFinished(t.Meta.ID)
if err != nil {
log.Errorln(err)
}
} }
} }
...@@ -98,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) { ...@@ -98,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
} }
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset sets dataset to dispatch for the master server.
// //
// SetDataset can be call multiple times from different nodes. But // SetDataset can be call multiple times at one pass. But only the first call
// only the first call will be honored. // will be honored.
//
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error { func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil) err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
} }
// getTask gets a new task from the master server. // getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) { func (c *Client) getTask(passID int) (Task, error) {
var t Task var t Task
err := c.conn.Call("Service.GetTask", 0, &t) err := c.conn.Call("Service.GetTask", passID, &t)
return t, err return t, err
} }
...@@ -118,6 +212,11 @@ func (c *Client) taskFinished(taskID int) error { ...@@ -118,6 +212,11 @@ func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil) return c.conn.Call("Service.TaskFinished", taskID, nil)
} }
// TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFailed", meta, nil)
}
// NextRecord returns next record in the dataset. // NextRecord returns next record in the dataset.
// //
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
...@@ -126,3 +225,11 @@ func (c *Client) NextRecord() ([]byte, error) { ...@@ -126,3 +225,11 @@ func (c *Client) NextRecord() ([]byte, error) {
r := <-c.ch r := <-c.ch
return r.r, r.err return r.r, r.err
} }
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
...@@ -40,22 +54,22 @@ func TestGetFinishTask(t *testing.T) { ...@@ -40,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
server := rpc.NewServer() server := rpc.NewServer()
err = server.Register(s) sErr = server.Register(s)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server) mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux) sErr = http.Serve(l, mux)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
}(l) }(l)
...@@ -66,11 +80,21 @@ func TestGetFinishTask(t *testing.T) { ...@@ -66,11 +80,21 @@ func TestGetFinishTask(t *testing.T) {
for i := 0; i < totalTask*chunkPerTask; i++ { for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1) w := recordio.NewWriter(f, -1, -1)
w.Write(nil) _, err = w.Write(nil)
if err != nil {
panic(err)
}
// call Close to force RecordIO writing a chunk. // call Close to force RecordIO writing a chunk.
w.Close() err = w.Close()
if err != nil {
panic(err)
}
}
err = f.Close()
if err != nil {
panic(err)
} }
f.Close()
// Manually intialize client to avoid calling c.getRecords() // Manually intialize client to avoid calling c.getRecords()
c := &Client{} c := &Client{}
...@@ -79,42 +103,56 @@ func TestGetFinishTask(t *testing.T) { ...@@ -79,42 +103,56 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1) ch := make(chan string, 1)
ch <- addr ch <- addr
go c.monitorMaster(ch) go c.monitorMaster(ch)
c.SetDataset([]string{path})
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask() task, cErr := c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("error: %v, pass: %d\n", cErr, i)
} }
tasks = append(tasks, task) tasks = append(tasks, task)
} }
_, err = c.getTask() // getting task before task finishes should return error
if err == nil { _, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].ID) cErr = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
tasks = tasks[1:] tasks = tasks[1:]
task, err := c.getTask() _, cErr = c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatal(err) t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
} }
tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.ID) cErr = c.taskFinished(task.Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatal(cErr)
} }
} }
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i) checkOnePass(i)
} }
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test package master_test
import ( import (
...@@ -6,8 +20,10 @@ import ( ...@@ -6,8 +20,10 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -15,6 +31,18 @@ import ( ...@@ -15,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) { func TestNextRecord(t *testing.T) {
const ( const (
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
...@@ -31,7 +59,7 @@ func TestNextRecord(t *testing.T) { ...@@ -31,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -55,32 +83,67 @@ func TestNextRecord(t *testing.T) { ...@@ -55,32 +83,67 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
w := recordio.NewWriter(f, -1, -1) w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
w.Write([]byte{byte(i)}) _, err = w.Write([]byte{byte(i)})
if err != nil {
panic(err)
}
}
err = w.Close()
if err != nil {
panic(err)
} }
w.Close()
f.Close()
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 { err = f.Close()
t.Fatal(pass, i, "Length should be 1.", r) if err != nil {
panic(err)
}
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
} }
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
if received[r[0]] { received := make(map[byte]bool)
t.Fatal(pass, i, "Received duplicate.", received, r) taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
} }
received[r[0]] = true }()
}
} }
wg.Wait()
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
...@@ -25,15 +39,12 @@ type EtcdClient struct { ...@@ -25,15 +39,12 @@ type EtcdClient struct {
statePath string statePath string
client *clientv3.Client client *clientv3.Client
lock *concurrency.Mutex lock *concurrency.Mutex
sess *concurrency.Session
} }
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debugf("Connecting to etcd at %v", endpoints)
// TODO(helin): gracefully shutdown etcd store. Becuase etcd
// store holds a etcd lock, even though the lock will expire
// when the lease timeout, we need to implement graceful
// shutdown to release the lock.
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
...@@ -53,14 +64,14 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -53,14 +64,14 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath) log.Infof("Trying to acquire lock at %s.", lockPath)
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("Successfully acquired lock at %s.", lockPath) log.Infof("Successfully acquired lock at %s.", lockPath)
put := clientv3.OpPut(addrPath, string(addr)) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -75,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -75,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
statePath: statePath, statePath: statePath,
client: cli, client: cli,
lock: lock, lock: lock,
sess: sess,
} }
return e, nil return e, nil
...@@ -143,9 +155,24 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -143,9 +155,24 @@ func (e *EtcdClient) Load() ([]byte, error) {
return state, nil return state, nil
} }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Errorln(newErr)
}
}
return err
}
// GetKey gets the value by the specify key. // GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key) resp, err := c.Get(ctx, key)
cancel() cancel()
if err != nil { if err != nil {
...@@ -159,8 +186,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { ...@@ -159,8 +186,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return string(v), nil return string(v), nil
} }
// WatchKey watches the specify key and send to valChan if there is some event. // watchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) { func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key) rch := c.Watch(context.Background(), key)
for wresp := range rch { for wresp := range rch {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import "sync" import "sync"
// InMemStore is an in memory implementation of Store interface. // InMemStore is an in memory implementation of Store interface.
// //
// It does not tolerate the fault that casues the program to crash. // It does not tolerate the fault that causes the program to crash.
type InMemStore struct { type InMemStore struct {
mu sync.Mutex mu sync.Mutex
buf []byte buf []byte
...@@ -26,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) { ...@@ -26,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
return m.buf, nil return m.buf, nil
} }
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
...@@ -5,6 +19,7 @@ import ( ...@@ -5,6 +19,7 @@ import (
"compress/gzip" "compress/gzip"
"encoding/gob" "encoding/gob"
"errors" "errors"
"math/rand"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
...@@ -19,10 +34,23 @@ const ( ...@@ -19,10 +34,23 @@ const (
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
) )
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state. // Store is the interface for save and load the master state.
type Store interface { type Store interface {
Save([]byte) error Save([]byte) error
Load() ([]byte, error) Load() ([]byte, error)
Shutdown() error
} }
// Chunk is a chunk of data consisted of several data instances. // Chunk is a chunk of data consisted of several data instances.
...@@ -31,40 +59,56 @@ type Chunk struct { ...@@ -31,40 +59,56 @@ type Chunk struct {
Index recordio.Index // chunk index Index recordio.Index // chunk index
} }
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}
// Task is the basic unit of data instances assigned to trainers. // Task is the basic unit of data instances assigned to trainers.
type Task struct { type Task struct {
ID int Meta TaskMeta
Chunks []Chunk Chunks []Chunk
} }
type taskEntry struct { type taskEntry struct {
Epoch int Task Task
NumTimeout int // A task fails if it's timeout or trainer reports it exits unnormally.
Task Task NumFailure int
} }
type taskQueues struct { type taskQueues struct {
Todo []taskEntry Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry Done []taskEntry
Failed []Task Failed []taskEntry
} }
// Service is the master server service. // Service is the master server service.
type Service struct { type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
timeoutMax int failureMax int
ready chan struct{}
store Store store Store
ready chan struct{}
initDone bool
mu sync.Mutex mu sync.Mutex
initDone bool
taskQueues taskQueues taskQueues taskQueues
currPass int
jobTasks []taskEntry
savingTrainer string
} }
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 // generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 { if chunksPerTask <= 0 {
chunksPerTask = 1 chunksPerTask = 1
} }
...@@ -73,8 +117,9 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -73,8 +117,9 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry var cur taskEntry
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
id++ counter++
id = timestamp + randStart + counter
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
} }
...@@ -83,7 +128,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -83,7 +128,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
if len(cur.Task.Chunks) > 0 { if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
result = append(result, cur) result = append(result, cur)
} }
...@@ -91,11 +136,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -91,11 +136,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
s := &Service{} s := &Service{}
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax s.failureMax = failureMax
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
...@@ -154,7 +199,7 @@ func (s *Service) recover() (bool, error) { ...@@ -154,7 +199,7 @@ func (s *Service) recover() (bool, error) {
// snapshot *must* be called with s.mu being held. // snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error { func (s *Service) snapshot() error {
// TOOD(helin): etcd request has a size limit, so the snapshot // TODO(helin): etcd request has a size limit, so the snapshot
// size is limited by the max request size. We should either // size is limited by the max request size. We should either
// divide the snapshot into smaller chunks and save under // divide the snapshot into smaller chunks and save under
// different keys, or configure the request size to be big // different keys, or configure the request size to be big
...@@ -209,6 +254,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -209,6 +254,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
count := index.NumChunks() count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := Chunk{ chunk := Chunk{
Path: path, Path: path,
...@@ -225,7 +271,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -225,7 +271,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
// //
// SetDataset can be call multiple times. But only the first call will // SetDataset can be call multiple times. But only the first call will
// be honored. // be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error { func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 { if len(globPaths) == 0 {
return errors.New("no dataset specified") return errors.New("no dataset specified")
} }
...@@ -244,19 +290,49 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -244,19 +290,49 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return err return err
} }
s.taskQueues.Todo = partition(chunks, s.chunksPerTask) s.jobTasks = partition(chunks, s.chunksPerTask)
s.taskQueues.Todo = s.jobTasks
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
return err return err
} }
close(s.ready) close(s.ready)
s.initDone = true s.initDone = true
return nil return nil
} }
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check or failed status report.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()
delete(s.taskQueues.Pending, t.Task.Meta.ID)
t.NumFailure++
if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t)
return
}
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() { return func() {
s.mu.Lock() s.mu.Lock()
...@@ -267,30 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -267,30 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return return
} }
if t.Epoch != epoch { s.processFailedTask(t, epoch)
// new epoch, task launched after the
// schedule of this timeout check.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}
log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
} }
} }
...@@ -305,52 +358,45 @@ func (s *Service) logFields() log.Fields { ...@@ -305,52 +358,45 @@ func (s *Service) logFields() log.Fields {
} }
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error { // passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select { select {
case <-s.ready: case <-s.ready:
} }
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if passID < s.currPass {
return ErrPassBefore
}
if passID > s.currPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 { if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 { log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
err := errors.New("all task failed") return ErrAllTaskFailed
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
} }
s.taskQueues.Todo = s.taskQueues.Done log.WithFields(s.logFields()).Warningln("No more available task.")
s.taskQueues.Done = nil return ErrNoMoreAvailable
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
t.Epoch++ t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:] s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
return err return err
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID) log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
...@@ -365,22 +411,24 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -365,22 +411,24 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.taskQueues.Pending[taskID] t, ok := s.taskQueues.Pending[taskID]
if !ok { if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err return nil
} }
// task finished, reset timeout // task finished, reset timeout
t.NumTimeout = 0 t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t) s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { // increase master side pass count if all tasks finished
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") s.currPass++
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Todo = s.jobTasks
s.taskQueues.Done = nil s.taskQueues.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.taskQueues.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass)
} }
err := s.snapshot() err := s.snapshot()
...@@ -389,3 +437,61 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -389,3 +437,61 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
} }
return err return err
} }
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID]
if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil
}
s.processFailedTask(t, meta.Epoch)
return nil
}
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import "testing" import "testing"
...@@ -30,7 +44,8 @@ func TestPartionIndex(t *testing.T) { ...@@ -30,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.ID != i { // test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }
......
package master_test
import (
"os"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutils.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()
libpaddle_go_optimizer.a
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m) target_link_libraries(paddle_go_optimizer stdc++ m)
# Copy library to the required place.
# See: go/pserver/optimizer.go:
# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
add_custom_command(TARGET paddle_go_optimizer POST_BUILD
COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}"
)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING) if(WITH_TESTING)
# FIXME: this test requires pserver which is not managed by the test # FIXME: this test requires pserver which is not managed by the test
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
/* /*
...@@ -34,7 +48,6 @@ import ( ...@@ -34,7 +48,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
...@@ -42,10 +55,10 @@ var curHandle C.paddle_pserver_client ...@@ -42,10 +55,10 @@ var curHandle C.paddle_pserver_client
func add(c *client.Client) C.paddle_pserver_client { func add(c *client.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle cli := curHandle
curHandle++ curHandle++
handleMap[client] = c handleMap[cli] = c
return client return cli
} }
func get(client C.paddle_pserver_client) *client.Client { func get(client C.paddle_pserver_client) *client.Client {
...@@ -63,7 +76,7 @@ func remove(client C.paddle_pserver_client) *client.Client { ...@@ -63,7 +76,7 @@ func remove(client C.paddle_pserver_client) *client.Client {
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr { if p == nil {
return nil return nil
} }
...@@ -101,11 +114,11 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli ...@@ -101,11 +114,11 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcdEndpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters) // TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcd_endpoints) addr := C.GoString(etcdEndpoints)
etcd_client := client.NewEtcd(addr) etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0)) c := client.NewClient(etcdClient, etcdClient.Desired(), selector(selected != 0))
return add(c) return add(c)
} }
...@@ -114,30 +127,36 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) { ...@@ -114,30 +127,36 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client) remove(client)
} }
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { if selected := c.BeginInitParams(); selected {
return 1 return 1
} }
return C.PSERVER_OK return 0
} }
//export paddle_init_param //export paddle_init_param
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int { func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int {
et := pserver.ElementType(param.element_type) et := pserver.ElementType(param.element_type)
name := C.GoString(param.name) name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{ pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: content}, Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(param_config, int(config_len)), Config: cArrayToSlice(paramConfig, int(configLen)),
} }
c := get(client) c := get(client)
err := c.InitParam(pc) err := c.InitParam(pc)
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name) log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Errorln(err)
...@@ -153,7 +172,7 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { ...@@ -153,7 +172,7 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningln("parameters already initialized, treat paddle_finish_init_params as sucessful.") log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK return C.PSERVER_OK
} }
...@@ -223,12 +242,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -223,12 +242,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
p := ps[i] p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nullPtr { if unsafe.Pointer(param) == nil {
log.Errorln("must pre-allocate parameter.") log.Errorln("must pre-allocate parameter.")
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
if unsafe.Pointer(param.content) != nullPtr { if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) { if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return C.PSERVER_ERROR return C.PSERVER_ERROR
...@@ -243,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -243,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return C.PSERVER_OK return C.PSERVER_OK
} }
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored func main() {} // Required but ignored
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer) cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
add_style_check_target(test_cclient test_cclient.c) add_style_check_target(test_cclient test_cclient.c)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
...@@ -97,9 +111,5 @@ retry: ...@@ -97,9 +111,5 @@ retry:
getParams(c); getParams(c);
} }
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0; return 0;
} }
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoints = "http://" + etcd_ip + ":2379"
print "etcd endpoints: ", etcd_endpoints
def main(): def main():
...@@ -9,30 +17,33 @@ def main(): ...@@ -9,30 +17,33 @@ def main():
# network config # network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x, y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'), param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
size=1, size=1,
act=paddle.activation.Linear(), act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b')) bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y) cost = paddle.layer.mse_cost(input=y_predict, label=y)
# create parameters # create parameters
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
is_local=False, is_local=False,
pserver_spec="localhost:3000") pserver_spec=etcd_endpoints,
use_etcd=True)
# event_handler to print training and testing info # event_handler to print training and testing info
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % ( print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost) event.pass_id, event.batch_id, event.cost)
...@@ -47,10 +58,14 @@ def main(): ...@@ -47,10 +58,14 @@ def main():
print "Test %d, %.2f" % (event.pass_id, result.cost) print "Test %d, %.2f" % (event.pass_id, result.cost)
# training # training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
uci_housing.train(), buf_size=500), cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2), batch_size=2),
feeding={'x': 0, feeding={'x': 0,
'y': 1}, 'y': 1},
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client package client
import ( import (
...@@ -205,35 +219,9 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) { ...@@ -205,35 +219,9 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return ps, nil return ps, nil
} }
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 { func strHash(s string) uint32 {
h := fnv.New32a() h := fnv.New32a()
h.Write([]byte(s)) _, _ = h.Write([]byte(s))
return h.Sum32() return h.Sum32()
} }
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client_test package client_test
import ( import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -42,7 +58,8 @@ func initClient() [numPserver]int { ...@@ -42,7 +58,8 @@ func initClient() [numPserver]int {
ports[i] = p ports[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -76,15 +93,33 @@ func initEtcdClient() { ...@@ -76,15 +93,33 @@ func initEtcdClient() {
log.Errorf("err %v", err) log.Errorf("err %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired) _, err = client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath) if err != nil {
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver)) panic(err)
}
_, err = client.Delete(ctx, pserver.PsPath)
if err != nil {
panic(err)
}
_, err = client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
if err != nil {
panic(err)
}
ports := initClient() ports := initClient()
for i := 0; i < numPserver; i++ { for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i])) _, err = client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
if err != nil {
panic(err)
}
} }
cancel() cancel()
client.Close() err = client.Close()
if err != nil {
panic(err)
}
} }
type selector bool type selector bool
...@@ -99,27 +134,34 @@ func (l lister) List() []client.Server { ...@@ -99,27 +134,34 @@ func (l lister) List() []client.Server {
return l return l
} }
func ClientTest(t *testing.T, c *client.Client) { func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams() selected := c.BeginInitParams()
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
const numParameter = 100 const numParameter = 1000
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb") config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
var p pserver.Parameter wg.Add(1)
p.Name = "p_" + strconv.Itoa(i) go func(i int) {
p.ElementType = pserver.Float32 var p pserver.Parameter
p.Content = make([]byte, (i+1)*100) p.Name = "p_" + strconv.Itoa(i)
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}) p.ElementType = pserver.Float32
if err != nil { p.Content = make([]byte, (i+1)*100)
t.Fatal(err) err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
} if err != nil {
t.Fatal(err)
}
wg.Done()
}(i)
} }
wg.Wait()
err = c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
...@@ -127,7 +169,7 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -127,7 +169,7 @@ func ClientTest(t *testing.T, c *client.Client) {
} }
var grads []pserver.Gradient var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ { for i := 0; i < numParameter; i++ {
var g pserver.Gradient var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i) g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32 g.ElementType = pserver.Float32
...@@ -135,9 +177,31 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -135,9 +177,31 @@ func ClientTest(t *testing.T, c *client.Client) {
grads = append(grads, g) grads = append(grads, g)
} }
err = c.SendGrads(grads) const paramPerGroup = 10
if err != nil { const numGroups = numParameter / paramPerGroup
t.Fatal(err)
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err := c.SendGrads(gs)
if err != nil {
t.Fatal(err)
}
wg.Done()
}(gs)
} }
names := make([]string, numParameter) names := make([]string, numParameter)
...@@ -145,20 +209,35 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -145,20 +209,35 @@ func ClientTest(t *testing.T, c *client.Client) {
names[i] = "p_" + strconv.Itoa(i) names[i] = "p_" + strconv.Itoa(i)
} }
params, err := c.GetParams(names) for i := 0; i < numGroups; i++ {
if err != nil { var ns []string
t.Fatal(err) if i == numGroups-1 {
} ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
if len(names) != len(params) { wg.Add(1)
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) go func(ns []string) {
} params, err := c.GetParams(ns)
if err != nil {
t.Fatal(err)
}
for i := range params { if len(ns) != len(params) {
if names[i] != params[i].Name { t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name) }
}
for i := range params {
if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
}
}
wg.Done()
}(ns)
} }
wg.Wait()
} }
func TestNativeClient(t *testing.T) { func TestNativeClient(t *testing.T) {
...@@ -168,13 +247,14 @@ func TestNativeClient(t *testing.T) { ...@@ -168,13 +247,14 @@ func TestNativeClient(t *testing.T) {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])} servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
} }
c1 := client.NewClient(lister(servers), len(servers), selector(true)) c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1) testClient(t, c1)
} }
// TODO: tmperary disable etcdClient test for dependency of etcd) // EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient(t *testing.T) { func EtcdClient(t *testing.T) {
initEtcdClient() initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints) etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true)) c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2) testClient(t, c2)
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client package client
import ( import (
...@@ -12,7 +26,7 @@ import ( ...@@ -12,7 +26,7 @@ import (
) )
const ( const (
DefaultEtcdTimeout time.Duration = 5 * time.Second defaultEtcdTimeout time.Duration = 5 * time.Second
) )
// EtcdClient is used by pserver client that is a part of trainer process. // EtcdClient is used by pserver client that is a part of trainer process.
...@@ -47,7 +61,7 @@ func (p *EtcdClient) Desired() int { ...@@ -47,7 +61,7 @@ func (p *EtcdClient) Desired() int {
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err) log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(p.timeout) time.Sleep(p.timeout)
continue continue
} }
...@@ -66,10 +80,10 @@ func (p *EtcdClient) List() []Server { ...@@ -66,10 +80,10 @@ func (p *EtcdClient) List() []Server {
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := p.client.Get(ctx, psKey)
cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(p.timeout)
...@@ -106,11 +120,11 @@ func NewEtcd(endpoints string) *EtcdClient { ...@@ -106,11 +120,11 @@ func NewEtcd(endpoints string) *EtcdClient {
for { for {
cli, err = clientv3.New(clientv3.Config{ cli, err = clientv3.New(clientv3.Config{
Endpoints: ep, Endpoints: ep,
DialTimeout: DefaultEtcdTimeout, DialTimeout: defaultEtcdTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("Init etcd connection failed: %v", err) log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout) time.Sleep(defaultEtcdTimeout)
continue continue
} }
break break
...@@ -118,7 +132,7 @@ func NewEtcd(endpoints string) *EtcdClient { ...@@ -118,7 +132,7 @@ func NewEtcd(endpoints string) *EtcdClient {
log.Infof("Connected to etcd: %s\n", endpoints) log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{ client := &EtcdClient{
client: cli, client: cli,
timeout: DefaultEtcdTimeout, timeout: defaultEtcdTimeout,
endpoints: ep, endpoints: ep,
} }
return client return client
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
...@@ -16,18 +30,23 @@ import ( ...@@ -16,18 +30,23 @@ import (
const ( const (
// PsDesired is etcd path for store desired pserver count // PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired" PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr // PsPath is the base dir for pserver to store their addr
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
retryTimeout = 5 * time.Second
) )
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination. // tolerance, service registry and coordination.
type EtcdClient struct { type EtcdClient struct {
numPservers int numPservers int
etcdEndpoints string endpoints string
etcdClient *clientv3.Client client *clientv3.Client
// etcdTimeout is also used as retry intervals. sess *concurrency.Session
etcdTimeout time.Duration dialTimeout time.Duration
ttlSec int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string externalIP string
// desired number of pservers in the job. // desired number of pservers in the job.
...@@ -36,19 +55,19 @@ type EtcdClient struct { ...@@ -36,19 +55,19 @@ type EtcdClient struct {
} }
// NewEtcdClient creates an EtcdClient // NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
return &EtcdClient{ return &EtcdClient{
etcdTimeout: timeout, dialTimeout: dialtimeout,
numPservers: numPservers, ttlSec: ttlSec,
etcdEndpoints: endpoints, numPservers: numPservers,
endpoints: endpoints,
} }
} }
// Register registers the pserver on etcd // Register registers the pserver on etcd
// //
// Register returns the index of the current pserver. // Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) { func (e *EtcdClient) Register(port int) (int, error) {
var err error var err error
e.externalIP, err = networkhelper.GetExternalIP() e.externalIP, err = networkhelper.GetExternalIP()
if err != nil { if err != nil {
...@@ -56,19 +75,26 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -56,19 +75,26 @@ func (e *EtcdClient) Register() (int, error) {
} }
// initialize connection to etcd. // initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",") ep := strings.Split(e.endpoints, ",")
for { for {
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: ep, Endpoints: ep,
DialTimeout: e.etcdTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.etcdClient = cli e.client = cli
log.Debugf("inited client to %s", e.etcdEndpoints) sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
log.Errorf("create etcd session error: %v", err)
time.Sleep(retryTimeout)
continue
}
e.sess = sess
log.Debugf("inited client to %s", e.endpoints)
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
...@@ -79,7 +105,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -79,7 +105,7 @@ func (e *EtcdClient) Register() (int, error) {
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
...@@ -90,18 +116,18 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -90,18 +116,18 @@ func (e *EtcdClient) Register() (int, error) {
// wait and set s.desired init value // wait and set s.desired init value
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
} }
...@@ -114,11 +140,11 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -114,11 +140,11 @@ func (e *EtcdClient) Register() (int, error) {
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error var err error
pserverIdx, err = e.registerPserverEtcd(ctx) pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
...@@ -128,19 +154,19 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -128,19 +154,19 @@ func (e *EtcdClient) Register() (int, error) {
} }
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired) dsStr := c.Get(PsDesired)
if dsStr == "" { if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers)) c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
} }
return nil return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
} }
// registerPserverEtcd registers pserver node on etcd using transaction. // registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
...@@ -149,35 +175,20 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -149,35 +175,20 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" { if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info // find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID)) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished") log.Debug("register finished")
idx = i idx = i
registered = true registered = true
break break
} }
} }
if registered == true { if registered {
return nil return nil
} }
return errors.New("not registerd, may due to already have enough pservers") return errors.New("not registered, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil { if err != nil {
...@@ -186,3 +197,47 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -186,3 +197,47 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil return idx, nil
} }
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.client.Get(ctx, key)
cancel()
if err != nil {
return []byte{}, err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return []byte{}, nil
}
v := kvs[0].Value
return v, nil
}
// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
cancel()
return err
}
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
var err error
if e.sess != nil {
err = e.sess.Close()
}
if e.client != nil {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
log.Errorln(newErr)
} else {
err = newErr
}
}
}
return err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/optimizer/optimizer.h" // #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>
...@@ -15,15 +28,14 @@ import ( ...@@ -15,15 +28,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct { type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType elementType ElementType
contentLen int
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr { if p == nil {
return nil return nil
} }
...@@ -35,22 +47,31 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -35,22 +47,31 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len] return (*[1 << 30]byte)(p)[:len:len]
} }
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{} o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType o.elementType = paramWithConfigs.Param.ElementType
o.contentLen = len(paramWithConfigs.Param.Content)
p := paramWithConfigs.Param p := paramWithConfigs.Param
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State
paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": len(p.Content), "ParamSize": paramBufferSize,
"ConfigSize": len(c), "ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:") }).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content))) cbuffer = C.malloc(paramBufferSize)
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize)
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
(*C.char)(nullPtr), 0)
return o return o
} }
...@@ -60,12 +81,22 @@ func (o *optimizer) GetWeights() []byte { ...@@ -60,12 +81,22 @@ func (o *optimizer) GetWeights() []byte {
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float) return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
} }
func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
}
func (o *optimizer) UpdateParameter(g Gradient) error { func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType { if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float) if o.contentLen != len(g.Content) {
return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content))
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 { if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r) return fmt.Errorf("optimizer update returned error code: %d", r)
} }
...@@ -73,8 +104,8 @@ func (o *optimizer) UpdateParameter(g Gradient) error { ...@@ -73,8 +104,8 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
} }
func (o *optimizer) Cleanup() { func (o *optimizer) Cleanup() {
if unsafe.Pointer(o.opt) != nullPtr { if unsafe.Pointer(o.opt) != nil {
C.paddle_release_optimizer(o.opt) C.paddle_release_optimizer(o.opt)
o.opt = (*C.struct_paddle_optimizer)(nullPtr) o.opt = (*C.struct_paddle_optimizer)(nil)
} }
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
...@@ -19,6 +33,6 @@ func TestOptimizerCreateRelease(t *testing.T) { ...@@ -19,6 +33,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p, Param: p,
Config: config, Config: config,
} }
o := newOptimizer(param) o := newOptimizer(param, nil)
o.Cleanup() o.Cleanup()
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")
// RPC error message.
const ( const (
// AlreadyInitialized is true if pserver is initialized AlreadyInitialized = "pserver already initialized"
AlreadyInitialized = "pserver already initialized" Uninitialized = "pserver not fully initialized"
// Uninitialized is true if pserver not fully initialized CheckpointMD5Failed = "checkpoint file MD5 validation failed"
Uninitialized = "pserver not fully initialized"
) )
// Supported element types // Supported element types.
const ( const (
Int32 ElementType = iota Int32 ElementType = iota
UInt32 UInt32
...@@ -39,31 +70,101 @@ type ParameterWithConfig struct { ...@@ -39,31 +70,101 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
} }
// checkpointMeta saves checkpoint metadata
type checkpointMeta struct {
UUID string `json:"uuid"`
MD5 string `json:"md5"`
Timestamp int64 `json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter type Gradient Parameter
// Service is the RPC service for pserver. // Service is the RPC service for pserver.
type Service struct { type Service struct {
initialized chan struct{} initialized chan struct{}
idx int idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
}
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
ParameterWithConfig
State []byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil {
return nil, err
}
if len(v) == 0 {
return nil, ErrCheckpointNotFound
}
var cpMeta checkpointMeta
if err = json.Unmarshal(v, &cpMeta); err != nil {
return nil, err
}
fn := filepath.Join(cpPath, cpMeta.UUID)
if _, err = os.Stat(fn); os.IsNotExist(err) {
return nil, err
}
content, err := ioutil.ReadFile(fn)
if err != nil {
return nil, err
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
return nil, errors.New(CheckpointMD5Failed)
}
mu sync.Mutex dec := gob.NewDecoder(bytes.NewReader(content))
optMap map[string]*optimizer cp := Checkpoint{}
if err = dec.Decode(cp); err != nil {
return nil, err
}
return cp, nil
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int) (*Service, error) { func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: interval,
checkpointPath: path,
client: client,
} }
s.optMap = make(map[string]*optimizer) s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
if cp != nil {
for _, item := range cp {
p := ParameterWithConfig{
Param: item.Param,
Config: item.Config,
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
}
return s, nil return s, nil
} }
// InitParam initializes a parameter. // InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
...@@ -78,13 +179,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -78,13 +179,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is // TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil return nil
} }
// FinishInitParams tells the parameter server that the parameter // FinishInitParams tells the parameter server that the parameter
// initialization has finished. // initialization has finished.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { func (s *Service) FinishInitParams(_ int, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
...@@ -97,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { ...@@ -97,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter // SendGrad sends gradient to parameter servers for parameter
// optimization. // optimization.
func (s *Service) SendGrad(g Gradient, dummy *int) error { func (s *Service) SendGrad(g Gradient, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
...@@ -132,17 +233,89 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -132,17 +233,89 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// learning optimization methods are stochastic in // learning optimization methods are stochastic in
// nature. This race condition is allowed deliberately // nature. This race condition is allowed deliberately
// to save the program from making a copy of the // to save the program from making a copy of the
// paramter content. // parameter content.
parameter.Name = name parameter.Name = name
parameter.ElementType = opt.elementType parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights() parameter.Content = opt.GetWeights()
return nil return nil
} }
// Save tells the parameter server to save parameters. // pserver save checkpoint
func (s *Service) Save(path string, dummy *int) error { func (s *Service) doCheckpoint() (err error) {
<-s.initialized <-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
// TODO cp := make([]parameterCheckpoint, len(s.optMap))
return nil index := 0
for name, opt := range s.optMap {
var pc parameterCheckpoint
pc.Param.Name = name
pc.Param.ElementType = opt.elementType
pc.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err = encoder.Encode(cp)
if err != nil {
return
}
cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().UnixNano()
h := md5.New()
cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, err := json.Marshal(cpMeta)
if err != nil {
return
}
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil {
return
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
}
f, err := os.Create(cpMeta.UUID)
if err != nil {
return
}
defer func() {
closeErr := f.Close()
if closeErr != nil {
if err != nil {
log.Errorln(closeErr)
} else {
// Set closeErr as return value.
err = closeErr
}
}
}()
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
if err != nil {
return
}
err = writer.Flush()
if err != nil {
return
}
return
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver_test package pserver_test
import ( import (
...@@ -15,7 +29,8 @@ const ( ...@@ -15,7 +29,8 @@ const (
) )
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -30,7 +45,7 @@ func TestServiceFull(t *testing.T) { ...@@ -30,7 +45,7 @@ func TestServiceFull(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var p1 pserver.Parameter var p1 pserver.Parameter
...@@ -39,40 +54,40 @@ func TestServiceFull(t *testing.T) { ...@@ -39,40 +54,40 @@ func TestServiceFull(t *testing.T) {
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var param pserver.Parameter var param pserver.Parameter
err = s.GetParam("param_b", &param) err = s.GetParam("param_b", &param)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
if !reflect.DeepEqual(param, p1) { if !reflect.DeepEqual(param, p1) {
t.FailNow() t.Fatal("not equal:", param, p1)
} }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil) err = s.SendGrad(g1, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.SendGrad(g2, nil) err = s.SendGrad(g2, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var param1 pserver.Parameter var param1 pserver.Parameter
err = s.GetParam("param_a", &param1) err = s.GetParam("param_a", &param1)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
// don't compare content, since it's already changed by // don't compare content, since it's already changed by
...@@ -81,36 +96,39 @@ func TestServiceFull(t *testing.T) { ...@@ -81,36 +96,39 @@ func TestServiceFull(t *testing.T) {
p.Content = nil p.Content = nil
if !reflect.DeepEqual(param1, p) { if !reflect.DeepEqual(param1, p) {
t.FailNow() t.Fatal("not equal:", param1, p)
} }
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err.Error() != pserver.AlreadyInitialized { if err.Error() != pserver.AlreadyInitialized {
t.FailNow() t.Fatal(err)
} }
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.Fatal(err)
} }
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -128,16 +146,6 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -128,16 +146,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{} ch <- struct{}{}
}() }()
wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
select { select {
...@@ -160,13 +168,17 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -160,13 +168,17 @@ func TestBlockUntilInitialized(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
wg.Wait() wg.Wait()
} }
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(network_helper_test)
endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper package networkhelper
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper package networkhelper
import "testing" import "testing"
......
...@@ -8,13 +8,14 @@ add_subdirectory(gserver) ...@@ -8,13 +8,14 @@ add_subdirectory(gserver)
add_subdirectory(pserver) add_subdirectory(pserver)
add_subdirectory(trainer) add_subdirectory(trainer)
add_subdirectory(scripts) add_subdirectory(scripts)
add_subdirectory(optimizer)
add_subdirectory(string) add_subdirectory(string)
if(Boost_FOUND) if(Boost_FOUND)
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif() endif()
if(WITH_C_API) if(WITH_C_API)
......
...@@ -64,11 +64,7 @@ ModelConfig* TrainerConfig::getModelConfig() const { ...@@ -64,11 +64,7 @@ ModelConfig* TrainerConfig::getModelConfig() const {
ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {} ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {}
ParameterConfig::~ParameterConfig() { ParameterConfig::~ParameterConfig() { delete m; }
if (m) {
delete m;
}
}
ParameterConfig* ParameterConfig::createParameterConfigFromParameterSharedPtr( ParameterConfig* ParameterConfig::createParameterConfigFromParameterSharedPtr(
void* ptr) { void* ptr) {
...@@ -98,11 +94,7 @@ void* ParameterConfig::getRawPtr() { return m->getConfigPtr(); } ...@@ -98,11 +94,7 @@ void* ParameterConfig::getRawPtr() { return m->getConfigPtr(); }
OptimizationConfig::OptimizationConfig() : m(new OptimizationConfigPrivate()) {} OptimizationConfig::OptimizationConfig() : m(new OptimizationConfigPrivate()) {}
OptimizationConfig::~OptimizationConfig() { OptimizationConfig::~OptimizationConfig() { delete m; }
if (m) {
delete m;
}
}
std::string OptimizationConfig::toProtoString() { std::string OptimizationConfig::toProtoString() {
return m->getConfig().SerializeAsString(); return m->getConfig().SerializeAsString();
......
...@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const { ...@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double Evaluator::getValue(const std::string name) const { double Evaluator::getValue(const std::string name) const {
paddle::Error err; paddle::Error err;
double v = m->rawPtr->getValue(name, &err); double v = m->rawPtr->getValue(name, &err);
if (err) { if (!err.isOK()) {
throw std::runtime_error(err.msg()); throw std::runtime_error(err.msg());
} }
return v; return v;
......
...@@ -843,7 +843,8 @@ public: ...@@ -843,7 +843,8 @@ public:
bool useSparseUpdater); bool useSparseUpdater);
static ParameterUpdater* createNewRemoteUpdater( static ParameterUpdater* createNewRemoteUpdater(
OptimizationConfig* config, OptimizationConfig* config,
const std::string pserverSpec) throw(UnsupportError); const std::string pserverSpec,
const bool useEtcd) throw(UnsupportError);
~ParameterUpdater(); ~ParameterUpdater();
/** /**
......
...@@ -53,11 +53,7 @@ struct ParameterTraverseCallbackPrivate { ...@@ -53,11 +53,7 @@ struct ParameterTraverseCallbackPrivate {
ParameterOptimizer::ParameterOptimizer() : m(new ParameterOptimizerPrivate()) {} ParameterOptimizer::ParameterOptimizer() : m(new ParameterOptimizerPrivate()) {}
ParameterOptimizer::~ParameterOptimizer() { ParameterOptimizer::~ParameterOptimizer() { delete m; }
if (m) {
delete m;
}
}
ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) { ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) {
CHECK(config != nullptr); CHECK(config != nullptr);
...@@ -104,11 +100,7 @@ std::vector<int> ParameterOptimizer::getParameterTypes() const { ...@@ -104,11 +100,7 @@ std::vector<int> ParameterOptimizer::getParameterTypes() const {
ParameterTraverseCallback::ParameterTraverseCallback() ParameterTraverseCallback::ParameterTraverseCallback()
: m(new ParameterTraverseCallbackPrivate()) {} : m(new ParameterTraverseCallbackPrivate()) {}
ParameterTraverseCallback::~ParameterTraverseCallback() { ParameterTraverseCallback::~ParameterTraverseCallback() { delete m; }
if (m) {
delete m;
}
}
void ParameterTraverseCallback::apply(const std::vector<Vector*>& vecs, void ParameterTraverseCallback::apply(const std::vector<Vector*>& vecs,
const ParameterConfig& conf, const ParameterConfig& conf,
......
...@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ...@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
ParameterUpdater *ParameterUpdater::createNewRemoteUpdater( ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
OptimizationConfig *config, OptimizationConfig *config,
const std::string pserverSpec) throw(UnsupportError) { const std::string pserverSpec,
const bool useEtcd) throw(UnsupportError) {
#ifndef PADDLE_WITHOUT_GOLANG #ifndef PADDLE_WITHOUT_GOLANG
auto updater = new ParameterUpdater(); auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::NewRemoteParameterUpdater( updater->m->updater.reset(new paddle::NewRemoteParameterUpdater(
config->m->getConfig(), pserverSpec)); config->m->getConfig(), pserverSpec, useEtcd));
return updater; return updater;
#else #else
throw UnsupportError(); throw UnsupportError();
......
...@@ -171,11 +171,7 @@ struct VectorPrivate { ...@@ -171,11 +171,7 @@ struct VectorPrivate {
Vector::Vector() : m(new VectorPrivate()) {} Vector::Vector() : m(new VectorPrivate()) {}
Vector::~Vector() { Vector::~Vector() { delete m; }
if (m) {
delete m;
}
}
Vector* Vector::createZero(size_t sz, bool useGpu) { Vector* Vector::createZero(size_t sz, bool useGpu) {
auto retVec = new Vector(); auto retVec = new Vector();
......
...@@ -17,73 +17,6 @@ limitations under the License. */ ...@@ -17,73 +17,6 @@ limitations under the License. */
#include "hl_base.h" #include "hl_base.h"
/**
* @brief Shrink column to feature.
*
* @param[in] dataCol expand data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] blockH filter height.
* @param[in] blockW filter width.
* @param[in] strideH stride height.
* @param[in] strideW stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[in] outputH output height.
* @param[in] outputW output width.
* @param[out] dataIm output image data.
* @param[in] alpha
* @param[in] beta
*/
extern void hl_shrink_col2feature(const real* dataCol,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataIm,
real alpha = 1.0f,
real beta = 0.0f);
/**
* @brief Expand feature to column.
*
* @param[in] dataIm input image data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] blockH filter height.
* @param[in] blockW filter width.
* @param[in] strideH stride height.
* @param[in] strideW stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[in] outputH output height.
* @param[in] outputW output width.
* @param[out] dataCol expand data.
*
*/
extern void hl_expand_feature2col(const real* dataIm,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataCol);
/** /**
* @brief Maximum pool forward. * @brief Maximum pool forward.
* *
......
...@@ -17,36 +17,6 @@ limitations under the License. */ ...@@ -17,36 +17,6 @@ limitations under the License. */
#include "hl_cnn.h" #include "hl_cnn.h"
inline void hl_shrink_col2feature(const real* dataCol,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataIm,
real alpha,
real beta) {}
inline void hl_expand_feature2col(const real* dataIm,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataCol) {}
inline void hl_maxpool_forward(const int frameCnt, inline void hl_maxpool_forward(const int frameCnt,
const real* inputData, const real* inputData,
const int channels, const int channels,
......
...@@ -18,134 +18,6 @@ limitations under the License. */ ...@@ -18,134 +18,6 @@ limitations under the License. */
#include "hl_cnn.h" #include "hl_cnn.h"
#include "hl_device_functions.cuh" #include "hl_device_functions.cuh"
__global__ void KeFeature2col(size_t n, size_t height, const real* data_im,
size_t blockH, size_t blockW, size_t width,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t height_col, size_t width_col,
real* data_col) {
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
size_t w_out = index % width_col;
index /= width_col;
size_t h_out = index % height_col;
size_t channel_in = index / height_col;
size_t channel_out = channel_in * blockH * blockW;
size_t h_in = h_out * strideH;
size_t w_in = w_out * strideW;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (size_t i = 0; i < blockH; ++i) {
for (size_t j = 0; j < blockW; ++j) {
int rIdx = int(h_in+i);
int cIdx = int(w_in+j);
if ((rIdx-(int)paddingH) >= (int)height ||
(rIdx-(int)paddingH) < 0 ||
(cIdx-(int)paddingW) >= (int)width ||
(cIdx-(int)paddingW) < 0) {
*data_col = 0;
} else {
rIdx = rIdx + channel_in*height - paddingH;
cIdx = cIdx - paddingW;
*data_col = data_im[rIdx* width + cIdx];
}
data_col += height_col * width_col;
}
}
}
}
void hl_expand_feature2col(const real* dataIm, size_t channels,
size_t height, size_t width,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t outputH, size_t outputW,
real* dataCol) {
size_t numKernels = channels * outputH * outputW;
size_t blocks = (numKernels + 1024 -1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
KeFeature2col<<< grid, threads, 0, STREAM_DEFAULT >>>
(numKernels, height, dataIm, blockH, blockW, width,
strideH, strideW, paddingH, paddingW,
outputH, outputW, dataCol);
CHECK_SYNC("hl_expand_feature2col failed");
}
__global__ void KeCol2Feature(size_t n, const real* data_col, size_t height,
size_t width, size_t channels,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t height_col, size_t width_col,
real* data_im, real alpha, real beta) {
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
real val = 0;
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width-2 * paddingW) &&
(h - (int)paddingH) >= 0 &&
(h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output
int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
int w_col_end =
min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH* blockW) + \
(h - h_col * (int)strideH) * (int)blockW +
(w - w_col * (int)strideW);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
h -= paddingH;
w -= paddingW;
real tD = data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
h*(width-2*paddingW) + w];
data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
h*(width-2*paddingW) + w] = alpha * val + beta*tD;
}
}
}
void hl_shrink_col2feature(const real * dataCol, size_t channels,
size_t height, size_t width,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t outputH, size_t outputW,
real* dataIm, real alpha, real beta) {
size_t numKernels = channels * (height + 2*paddingH) * (width + 2*paddingW);
size_t blocks = (numKernels + 1024 -1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
KeCol2Feature<<< grid, threads, 0, STREAM_DEFAULT >>>
(numKernels, dataCol, height + 2*paddingH, width + 2*paddingW,
channels, blockH, blockW, strideH, strideW, paddingH, paddingW,
outputH, outputW, dataIm, alpha, beta);
CHECK_SYNC("hl_shrink_col2feature failed");
}
__global__ void KeMaxPoolForward(const int nthreads, const real* inputData, __global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
const int channels, const int height, const int channels, const int height,
const int width, const int width,
......
...@@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, ...@@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
real alpha = 1.0f; real alpha = 1.0f;
real beta = 1.0f; real beta = 1.0f;
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size;
if (batch_size > 1024 && g_cudnn_lib_version < 6000) {
LOG(INFO) << " To process current batch data with size " << batch_size
<< " (>1024), cudnnBatchNorm requires cuDNN version >= 6000."
<< " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED,"
<< " just recompile PaddlePaddle with cuDNN >= 6000, replacing"
<< " current version " << g_cudnn_lib_version;
}
CHECK_CUDNN( CHECK_CUDNN(
dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle,
mode, mode,
......
...@@ -269,8 +269,7 @@ void hl_sequence2batch_copy_padding(real* batch, ...@@ -269,8 +269,7 @@ void hl_sequence2batch_copy_padding(real* batch,
int blockDimY = CUDA_BLOCK_SIZE / blockDimX; int blockDimY = CUDA_BLOCK_SIZE / blockDimX;
dim3 threads(blockDimX, blockDimY); dim3 threads(blockDimX, blockDimY);
int gridDimX = (maxSequenceLength * blockDimX + CUDA_BLOCK_SIZE - 1) / int gridDimX = (maxSequenceLength + blockDimY - 1) / blockDimY;
CUDA_BLOCK_SIZE;
int gridDimY = numSequences; int gridDimY = numSequences;
dim3 grid(gridDimX, gridDimY); dim3 grid(gridDimX, gridDimY);
...@@ -330,7 +329,7 @@ __global__ void KeSequenceAvgForward(real* dst, ...@@ -330,7 +329,7 @@ __global__ void KeSequenceAvgForward(real* dst,
} }
sum = mode == 1 ? sum : sum = mode == 1 ? sum :
(mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength)); (mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength));
dst[gid] = sum; dst[gid] += sum;
} }
} }
......
# ddim lib # ddim lib
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc) cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(attr_type SRCS attr_type.proto) proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type) proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net)
cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS backward)
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/framework/enforce.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -41,6 +42,35 @@ class DefaultValueSetter { ...@@ -41,6 +42,35 @@ class DefaultValueSetter {
T default_value_; T default_value_;
}; };
template <typename T>
class EnumInContainer {
public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(T& val) const {
PADDLE_ENFORCE(container_.find(val) != container_.end(),
"Value %s is not in enum container %s", val,
ContainerDebugString());
}
private:
std::string ContainerDebugString() const {
std::ostringstream sout;
sout << "[";
size_t cnt = 0;
for (auto& v : container_) {
sout << v;
++cnt;
if (cnt != container_.size()) {
sout << " ,";
}
}
sout << "]";
return sout.str();
}
std::unordered_set<T> container_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
...@@ -50,6 +80,11 @@ class TypedAttrChecker { ...@@ -50,6 +80,11 @@ class TypedAttrChecker {
public: public:
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
value_checkers_.push_back(EnumInContainer<T>(range));
return *this;
}
TypedAttrChecker& LargerThan(const T& lower_bound) { TypedAttrChecker& LargerThan(const T& lower_bound) {
value_checkers_.push_back(LargerThanChecker<T>(lower_bound)); value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
return *this; return *this;
......
此差异已折叠。
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_set>
#include "operator.h"
namespace paddle {
namespace framework {
// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
} // namespace paddle
## Operator/expression 's Backward
### Motivation
In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation lineage, the operator/ expression's Backward feature will generate the backward pass respect to forward pass.
### Implement : gradient operator registry
| | forward operator | backward operator |
| ---------------------- | ---------------- | -------------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
Inputs/Outputs means the input/output of the operator, InputGradients/OutputGradients is the gradient respect to forward opeartor. Forward operator and Backward operator are isomorphic, save their corresponding needs into member attribute.
We use a global hash map record the gradient operators available, follow the philosophy of minimum core, make operator pluggable unit. Each gradient is an operator and it needs to regist itself.
grad_op_builder(fengjiayi)
### Implement : Backward network
given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
1. bla bla bla (yuyang)
2. NetOp
when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively and ensure them done. During the process, we need to collect the `OutputGradients` name.
We share variable in the same scope, as a result, duplicate operator `OutputGradients` will overwirte then duplicate variable.
![./images/duplicate_op]()
Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator instead.
![./images/duplicate_op2]()
​ Then collect the sub graph OutputGradients/InputGradients as the NetOp's and return it.
此差异已折叠。
此差异已折叠。
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <initializer_list> #include <initializer_list>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "paddle/framework/dim.h" #include "paddle/framework/dim.h"
#include "paddle/platform/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -27,7 +42,9 @@ struct DDim { ...@@ -27,7 +42,9 @@ struct DDim {
DDim() : var(Dim<1>()) {} DDim() : var(Dim<1>()) {}
template <int D> template <int D>
DDim(const Dim<D>& in) : var(in) {} explicit DDim(const Dim<D>& in) : var(in) {}
/*implicit*/ DDim(std::initializer_list<int> init_list);
template <int D> template <int D>
DDim& operator=(const Dim<D>& in) { DDim& operator=(const Dim<D>& in) {
...@@ -57,6 +74,8 @@ struct DDim { ...@@ -57,6 +74,8 @@ struct DDim {
DDim operator+(DDim d) const; DDim operator+(DDim d) const;
DDim operator*(DDim d) const; DDim operator*(DDim d) const;
ssize_t size() const;
}; };
/** /**
...@@ -81,6 +100,15 @@ std::vector<int> vectorize(const DDim& ddim); ...@@ -81,6 +100,15 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t product(const DDim& ddim); ssize_t product(const DDim& ddim);
/**
* \brief Slice a ddim
*
* Slice dim with [begin, end).
* e.g. DDim d = make_ddim({1,2,3,4,5});
* slice_ddim(d, 1, 3); ====> {2,3}
*/
DDim slice_ddim(const DDim& dim, int begin, int end);
/** /**
* \brief What is the length of this dimension? * \brief What is the length of this dimension?
* *
......
...@@ -49,9 +49,30 @@ TEST(DDim, Equality) { ...@@ -49,9 +49,30 @@ TEST(DDim, Equality) {
// arity of a DDim // arity of a DDim
EXPECT_EQ(paddle::framework::arity(ddim), 3); EXPECT_EQ(paddle::framework::arity(ddim), 3);
EXPECT_EQ(ddim.size(), 3);
// product of a DDim // product of a DDim
EXPECT_EQ(paddle::framework::product(vddim), 45); EXPECT_EQ(paddle::framework::product(vddim), 45);
EXPECT_EQ(
paddle::framework::product(paddle::framework::make_ddim({3, 2, 5, 3})),
90);
// slice a DDim
paddle::framework::DDim ddim2 =
paddle::framework::make_ddim({1, 2, 3, 4, 5, 6});
paddle::framework::DDim ss = paddle::framework::slice_ddim(ddim2, 2, 5);
EXPECT_EQ(arity(ss), 3);
EXPECT_EQ(ss[0], 3);
EXPECT_EQ(ss[1], 4);
EXPECT_EQ(ss[2], 5);
paddle::framework::DDim ss2 = paddle::framework::slice_ddim(ddim2, 0, 6);
EXPECT_EQ(arity(ss2), 6);
EXPECT_EQ(ss2[0], 1);
EXPECT_EQ(ss2[1], 2);
EXPECT_EQ(ss2[2], 3);
EXPECT_EQ(ss2[3], 4);
EXPECT_EQ(ss2[4], 5);
EXPECT_EQ(ss2[5], 6);
} }
TEST(DDim, Print) { TEST(DDim, Print) {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册