提交 fc2e6f1c 编写于 作者: C caoying03

Merge branch 'develop' into print_attention_weight

#!/bin/bash
set -e
readonly VERSION="3.8"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@
...@@ -24,4 +24,5 @@ cmake-build-* ...@@ -24,4 +24,5 @@ cmake-build-*
python/paddle/v2/framework/core.so python/paddle/v2/framework/core.so
CMakeFiles CMakeFiles
cmake_install.cmake cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
...@@ -17,14 +17,20 @@ ...@@ -17,14 +17,20 @@
- id: detect-private-key - id: detect-private-key
files: (?!.*third_party)^.*$ | (?!.*book)^.*$ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - repo: local
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
hooks: hooks:
- id: clang-formater - id: clang-format-with-version-check
- repo: https://github.com/dnephin/pre-commit-golang name: clang-format
sha: e4693a4c282b4fc878eda172a929f7a6508e7d16 description: Format files with ClangFormat.
entry: ./.clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks: hooks:
- id: go-fmt - id: go-fmt
files: (.*\.go) types:
- id: go-lint - go
files: (.*\.go) - id: gometalinter
types:
- go
...@@ -4,6 +4,7 @@ cache: ...@@ -4,6 +4,7 @@ cache:
- $HOME/.ccache - $HOME/.ccache
- $HOME/.cache/pip - $HOME/.cache/pip
- $TRAVIS_BUILD_DIR/build/third_party - $TRAVIS_BUILD_DIR/build/third_party
- $TRAVIS_BUILD_DIR/build_android/third_party
sudo: required sudo: required
dist: trusty dist: trusty
os: os:
...@@ -11,6 +12,7 @@ os: ...@@ -11,6 +12,7 @@ os:
env: env:
- JOB=build_doc - JOB=build_doc
- JOB=check_style - JOB=check_style
- JOB=build_android
addons: addons:
apt: apt:
packages: packages:
...@@ -35,10 +37,12 @@ before_install: ...@@ -35,10 +37,12 @@ before_install:
- if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi - if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version. # protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker - pip install -r $TRAVIS_BUILD_DIR/python/requirements.txt
- pip install rarfile - pip install wheel sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit LinkChecker
- curl https://glide.sh/get | bash - curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)" - eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- go get -u github.com/alecthomas/gometalinter
- gometalinter --install
- | - |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script: script:
......
...@@ -13,10 +13,9 @@ ...@@ -13,10 +13,9 @@
# limitations under the License # limitations under the License
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
include(system) include(system)
...@@ -37,6 +36,8 @@ include(simd) ...@@ -37,6 +36,8 @@ include(simd)
################################ Configurations ####################################### ################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
...@@ -54,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) ...@@ -54,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON) option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -75,6 +77,10 @@ if(ANDROID) ...@@ -75,6 +77,10 @@ if(ANDROID)
"Disable PYTHON when cross-compiling for Android" FORCE) "Disable PYTHON when cross-compiling for Android" FORCE)
set(WITH_RDMA OFF CACHE STRING set(WITH_RDMA OFF CACHE STRING
"Disable RDMA when cross-compiling for Android" FORCE) "Disable RDMA when cross-compiling for Android" FORCE)
set(WITH_MKLDNN OFF CACHE STRING
"Disable MKLDNN when cross-compiling for Android" FORCE)
set(WITH_MKLML OFF CACHE STRING
"Disable MKLML package when cross-compiling for Android" FORCE)
endif(ANDROID) endif(ANDROID)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
...@@ -88,6 +94,7 @@ endif() ...@@ -88,6 +94,7 @@ endif()
######################################################################################## ########################################################################################
include(external/mklml) # download mklml package
include(external/zlib) # download, build, install zlib include(external/zlib) # download, build, install zlib
include(external/gflags) # download, build, install gflags include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog include(external/glog) # download, build, install glog
...@@ -95,6 +102,7 @@ include(external/gtest) # download, build, install gtest ...@@ -95,6 +102,7 @@ include(external/gtest) # download, build, install gtest
include(external/protobuf) # download, build, install protobuf include(external/protobuf) # download, build, install protobuf
include(external/python) # download, build, install python include(external/python) # download, build, install python
include(external/openblas) # download, build, install openblas include(external/openblas) # download, build, install openblas
include(external/mkldnn) # download, build, install mkldnn
include(external/swig) # download, build, install swig include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any include(external/any) # download libn::any
...@@ -114,8 +122,8 @@ include(version) # set PADDLE_VERSION ...@@ -114,8 +122,8 @@ include(version) # set PADDLE_VERSION
include(coveralls) # set code coverage include(coveralls) # set code coverage
include_directories("${PROJ_ROOT}") include_directories("${PADDLE_SOURCE_DIR}")
include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${PADDLE_SOURCE_DIR}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c") include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS}) include_directories(${Boost_INCLUDE_DIRS})
...@@ -130,14 +138,19 @@ set(EXTERNAL_LIBS ...@@ -130,14 +138,19 @@ set(EXTERNAL_LIBS
) )
if(WITH_GPU) if(WITH_GPU)
list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO) if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
endif(NOT WITH_DSO) endif(NOT WITH_DSO)
endif(WITH_GPU) endif(WITH_GPU)
if(WITH_MKLDNN)
list(APPEND EXTERNAL_LIBS ${MKLDNN_LIB} ${MKLDNN_IOMP_LIB})
endif()
if(USE_NNPACK) if(USE_NNPACK)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB} "rt") include(external/nnpack)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIBS})
endif(USE_NNPACK) endif(USE_NNPACK)
add_subdirectory(proto) add_subdirectory(proto)
...@@ -152,10 +165,12 @@ if(WITH_GOLANG) ...@@ -152,10 +165,12 @@ if(WITH_GOLANG)
add_subdirectory(go) add_subdirectory(go)
endif(WITH_GOLANG) endif(WITH_GOLANG)
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
add_subdirectory(paddle) add_subdirectory(paddle)
if(WITH_PYTHON) if(WITH_PYTHON)
add_subdirectory(python) add_subdirectory(python)
endif() endif()
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
...@@ -25,27 +25,26 @@ COPY ./paddle/scripts/docker/root/ /root/ ...@@ -25,27 +25,26 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y \ apt-get install -y \
git python-pip python-dev openssh-server bison \ git python-pip python-dev openssh-server bison \
wget unzip tar xz-utils bzip2 gzip coreutils ntp \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-numpy python-matplotlib gcc g++ \ python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \ automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \ liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \ net-tools && \
apt-get clean -y apt-get clean -y
# Install Go and glide # Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \ RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \
tar -C /usr/local -xzf go.tgz && \ tar -xz -C /usr/local && \
mkdir /root/gopath && \ mkdir /root/gopath && \
mkdir /root/gopath/bin && \ mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \ mkdir /root/gopath/src
rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT. # should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# install glide # install glide
RUN curl -q https://glide.sh/get | sh RUN curl -s -q https://glide.sh/get | sh
# git credential to skip password typing # git credential to skip password typing
RUN git config --global credential.helper store RUN git config --global credential.helper store
...@@ -56,19 +55,23 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 ...@@ -56,19 +55,23 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter # FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter
# version util jupyter fixes this issue. # version util jupyter fixes this issue.
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \
pip install -U 'protobuf==3.1.0' && \ pip install -U wheel && \
pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \ pip install -U docopt PyYAML sphinx && \
pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \ pip install -U sphinx-rtd-theme==0.1.9 recommonmark
pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \
RUN pip install pre-commit 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install rarfile pip install opencv-python
COPY ./python/requirements.txt /root/
RUN pip install -r /root/requirements.txt
# To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
# the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2 # the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2
RUN apt-get install -y libssl-dev libffi-dev RUN apt-get install -y libssl-dev libffi-dev
RUN pip install certifi urllib3[secure] RUN pip install certifi urllib3[secure]
# Install woboq_codebrowser to /woboq # Install woboq_codebrowser to /woboq
RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \ RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \
(cd /woboq \ (cd /woboq \
......
...@@ -14,6 +14,17 @@ RUN apt-get update && \ ...@@ -14,6 +14,17 @@ RUN apt-get update && \
wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \ wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
apt-get clean -y apt-get clean -y
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# git credential to skip password typing # git credential to skip password typing
RUN git config --global credential.helper store RUN git config --global credential.helper store
......
...@@ -72,7 +72,7 @@ We provide [English](http://doc.paddlepaddle.org/develop/doc/) and ...@@ -72,7 +72,7 @@ We provide [English](http://doc.paddlepaddle.org/develop/doc/) and
- [Deep Learning 101](http://book.paddlepaddle.org/index.html) - [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from the this online interactive book that can run in Jupyter Notebook. You might want to start from this online interactive book that can run in Jupyter Notebook.
- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html) - [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html)
......
...@@ -15,23 +15,44 @@ ...@@ -15,23 +15,44 @@
set(CBLAS_FOUND OFF) set(CBLAS_FOUND OFF)
## Find MKL First. ## Find MKLML First.
set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs") if(WITH_MKLML AND MKLML_INC_DIR AND MKLML_LIB)
set(MKL_ROOT ${INTEL_ROOT}/mkl CACHE PATH "Folder contains MKL") set(CBLAS_FOUND ON)
set(CBLAS_PROVIDER MKLML)
set(CBLAS_INC_DIR ${MKLML_INC_DIR})
set(CBLAS_LIBRARIES ${MKLML_LIB})
add_definitions(-DPADDLE_USE_MKLML)
add_definitions(-DLAPACK_FOUND)
message(STATUS "Found cblas and lapack in MKLML "
"(include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
return()
endif()
## Then find MKL.
set(INTEL_MKL_ROOT "/opt/intel/mkl" CACHE PATH "Folder contains intel mkl libs")
set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder contains env MKL")
set(MKL_INCLUDE_SEARCH_PATHS
${MKL_ROOT}/include
${INTEL_MKL_ROOT}/include)
set(MKL_LIB_SEARCH_PATHS
${MKL_ROOT}/lib
${MKL_ROOT}/lib/intel64
${INTEL_MKL_ROOT}/lib
${INTEL_MKL_ROOT}/lib/intel64)
find_path(MKL_INC_DIR mkl.h PATHS find_path(MKL_INC_DIR mkl.h PATHS
${MKL_ROOT}/include) ${MKL_INCLUDE_SEARCH_PATHS})
find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS
${MKL_ROOT}/include) ${MKL_INCLUDE_SEARCH_PATHS})
find_library(MKL_CORE_LIB NAMES mkl_core PATHS find_library(MKL_CORE_LIB NAMES mkl_core PATHS
${MKL_ROOT}/lib ${MKL_LIB_SEARCH_PATHS})
${MKL_ROOT}/lib/intel64)
find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS
${MKL_ROOT}/lib ${MKL_LIB_SEARCH_PATHS})
${MKL_ROOT}/lib/intel64)
find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS
${MKL_ROOT}/lib ${MKL_LIB_SEARCH_PATHS})
${MKL_ROOT}/lib/intel64)
if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64)
set(CBLAS_FOUND ON) set(CBLAS_FOUND ON)
......
...@@ -28,6 +28,10 @@ if(NOT WITH_TIMER) ...@@ -28,6 +28,10 @@ if(NOT WITH_TIMER)
add_definitions(-DPADDLE_DISABLE_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER)
endif(NOT WITH_TIMER) endif(NOT WITH_TIMER)
if(USE_EIGEN_FOR_BLAS)
add_definitions(-DPADDLE_USE_EIGEN_FOR_BLAS)
endif(USE_EIGEN_FOR_BLAS)
if(NOT WITH_PROFILER) if(NOT WITH_PROFILER)
add_definitions(-DPADDLE_DISABLE_PROFILER) add_definitions(-DPADDLE_DISABLE_PROFILER)
endif(NOT WITH_PROFILER) endif(NOT WITH_PROFILER)
...@@ -67,6 +71,28 @@ else() ...@@ -67,6 +71,28 @@ else()
include_directories(${CUDA_TOOLKIT_INCLUDE}) include_directories(${CUDA_TOOLKIT_INCLUDE})
endif(NOT WITH_GPU) endif(NOT WITH_GPU)
if(WITH_MKLDNN)
add_definitions(-DPADDLE_USE_MKLDNN)
if (WITH_MKLML AND MKLDNN_IOMP_DIR)
message(STATUS "Enable Intel OpenMP at ${MKLDNN_IOMP_DIR}")
set(OPENMP_FLAGS "-fopenmp")
set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
else()
find_package(OpenMP)
if(OPENMP_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
else()
message(WARNING "Can not find OpenMP."
"Some performance features in MKLDNN may not be available")
endif()
endif()
endif(WITH_MKLDNN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
...@@ -102,12 +128,19 @@ if(WITH_GOLANG) ...@@ -102,12 +128,19 @@ if(WITH_GOLANG)
message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide") message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide")
endif() endif()
add_custom_target(go_vendor) # this command will only run when the file it depends is missing
add_custom_command(TARGET go_vendor # or has changed, or the output is missing.
add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/glide
COMMAND env GOPATH=${GOPATH} ${GLIDE} install COMMAND env GOPATH=${GOPATH} ${GLIDE} install
COMMAND touch ${CMAKE_BINARY_DIR}/glide
DEPENDS ${PADDLE_SOURCE_DIR}/go/glide.lock
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go" WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go"
) )
add_dependencies(go_vendor go_path)
# depends on the custom command which outputs
# ${CMAKE_BINARY_DIR}/glide, the custom command does not need to
# run every time this target is built.
add_custom_target(go_vendor DEPENDS ${CMAKE_BINARY_DIR}/glide go_path)
endif() endif()
endif(WITH_GOLANG) endif(WITH_GOLANG)
...@@ -27,7 +27,8 @@ set(IGNORE_PATTERN ...@@ -27,7 +27,8 @@ set(IGNORE_PATTERN
.*cblas\\.h.* .*cblas\\.h.*
.*\\.pb\\.txt .*\\.pb\\.txt
.*LtrDataProvider.* .*LtrDataProvider.*
.*MultiDataProvider.*) .*MultiDataProvider.*
.*pb.*)
# add_style_check_target # add_style_check_target
# #
...@@ -41,27 +42,21 @@ macro(add_style_check_target TARGET_NAME) ...@@ -41,27 +42,21 @@ macro(add_style_check_target TARGET_NAME)
if(WITH_STYLE_CHECK) if(WITH_STYLE_CHECK)
set(SOURCES_LIST ${ARGN}) set(SOURCES_LIST ${ARGN})
list(REMOVE_DUPLICATES SOURCES_LIST) list(REMOVE_DUPLICATES SOURCES_LIST)
list(SORT SOURCES_LIST)
foreach(filename ${SOURCES_LIST}) foreach(filename ${SOURCES_LIST})
set(LINT ON)
foreach(pattern ${IGNORE_PATTERN}) foreach(pattern ${IGNORE_PATTERN})
if(filename MATCHES ${pattern}) if(filename MATCHES ${pattern})
message(STATUS "DROP LINT ${filename}") list(REMOVE_ITEM SOURCES_LIST ${filename})
set(LINT OFF)
endif() endif()
endforeach() endforeach()
if(LINT MATCHES ON)
get_filename_component(base_filename ${filename} NAME)
set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint)
add_custom_command(OUTPUT ${CUR_GEN}
PRE_BUILD
COMMAND env ${py_env} "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
"--filter=${STYLE_FILTER}"
"--write-success=${CUR_GEN}" ${filename}
DEPENDS ${filename}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endforeach() endforeach()
if(SOURCES_LIST)
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/scripts/cpplint.py"
"--filter=${STYLE_FILTER}"
${SOURCES_LIST}
COMMENT "cpplint: Checking source code style"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endif() endif()
endmacro() endmacro()
...@@ -108,6 +108,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -108,6 +108,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a") IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android) SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
SET(CMAKE_SYSTEM_PROCESSOR aarch64)
ENDIF() ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-") SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF() ENDIF()
...@@ -166,7 +167,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -166,7 +167,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a") IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a) LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF() ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}") STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
...@@ -193,6 +194,10 @@ ELSE() ...@@ -193,6 +194,10 @@ ELSE()
SET(CMAKE_ANDROID_STANDALONE_TOOLCHAIN ${ANDROID_STANDALONE_TOOLCHAIN}) SET(CMAKE_ANDROID_STANDALONE_TOOLCHAIN ${ANDROID_STANDALONE_TOOLCHAIN})
ENDIF() ENDIF()
SET(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ABI}) SET(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ABI})
SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE}) IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON}) SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE})
IF(ANDROID_ABI STREQUAL "armeabi-v7a")
SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON})
ENDIF()
ENDIF()
ENDIF() ENDIF()
...@@ -2,7 +2,7 @@ if(NOT WITH_GPU) ...@@ -2,7 +2,7 @@ if(NOT WITH_GPU)
return() return()
endif() endif()
set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT") set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT")
find_path(CUDNN_INCLUDE_DIR cudnn.h find_path(CUDNN_INCLUDE_DIR cudnn.h
PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
......
...@@ -7,8 +7,8 @@ INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any) ...@@ -7,8 +7,8 @@ INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add( ExternalProject_Add(
extern_lib_any extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/thelink2012/any.git" GIT_REPOSITORY "https://github.com/PaddlePaddle/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020" GIT_TAG "15595d8324be9e8a9a80d9ae442fdd12bd66df5d"
PREFIX ${ANY_SOURCE_DIR} PREFIX ${ANY_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3) ...@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" GIT_TAG "master"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) ...@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
ExternalProject_Add( ExternalProject_Add(
extern_gflags extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/gflags/gflags.git" # TODO(yiwang): The annoying warnings mentioned in
# https://github.com/PaddlePaddle/Paddle/issues/3277 are caused by
# gflags. I fired a PR https://github.com/gflags/gflags/pull/230
# to fix it. Before it gets accepted by the gflags team, we use
# my personal fork, which contains above fix, temporarily. Let's
# change this back to the official Github repo once my PR is
# merged.
GIT_REPOSITORY "https://github.com/wangkuiyi/gflags.git"
PREFIX ${GFLAGS_SOURCES_DIR} PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -52,6 +52,7 @@ ExternalProject_Add( ...@@ -52,6 +52,7 @@ ExternalProject_Add(
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog) ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog) LIST(APPEND external_project_dependencies glog)
...@@ -34,9 +34,15 @@ IF(WITH_TESTING) ...@@ -34,9 +34,15 @@ IF(WITH_TESTING)
"${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE) "${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE)
ENDIF(WIN32) ENDIF(WIN32)
IF(WITH_MKLML)
# wait for mklml downloading completed
SET(GTEST_DEPENDS ${MKLML_PROJECT})
ENDIF()
ExternalProject_Add( ExternalProject_Add(
extern_gtest extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${GTEST_DEPENDS}
GIT_REPOSITORY "https://github.com/google/googletest.git" GIT_REPOSITORY "https://github.com/google/googletest.git"
GIT_TAG "release-1.8.0" GIT_TAG "release-1.8.0"
PREFIX ${GTEST_SOURCES_DIR} PREFIX ${GTEST_SOURCES_DIR}
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_MKLDNN})
return()
ENDIF(NOT ${WITH_MKLDNN})
INCLUDE(ExternalProject)
SET(MKLDNN_PROJECT "extern_mkldnn")
SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with MKLDNN in Paddle yet."
"Force WITH_MKLDNN=OFF")
SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in Windows and MacOS" FORCE)
return()
ENDIF()
SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR})
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
SET(MKLDNN_MKLROOT ${MKLML_ROOT})
SET(MKLDNN_IOMP_LIB ${MKLML_IOMP_LIB})
SET(MKLDNN_IOMP_DIR ${MKLML_LIB_DIR})
MESSAGE(STATUS "Build MKLDNN with ${MKLDNN_MKLROOT}")
ENDIF()
ExternalProject_Add(
${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "v0.9"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DMKLROOT=${MKLDNN_MKLROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
-DMKLROOT:PATH=${MKLDNN_MKLROOT}
)
ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIB}")
LIST(APPEND external_project_dependencies mkldnn)
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_MKLML})
return()
ENDIF(NOT ${WITH_MKLML})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with MKLML in Paddle yet."
"Force WITH_MKLML=OFF")
SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(MKLML_PROJECT "extern_mklml")
SET(MKLML_VER "mklml_lnx_2018.0.20170720")
SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz")
SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
SET(MKLML_DST_DIR "mklml")
SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER})
SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
FILE(WRITE ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(MKLML)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${MKLML_VER}\n"
" DESTINATION ${MKLML_DST_DIR})\n")
ExternalProject_Add(
${MKLML_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${MKLML_SOURCE_DIR}
DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate -qO- ${MKLML_URL} | tar xz -C ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLML_INSTALL_ROOT}
)
ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB})
ADD_DEPENDENCIES(mklml ${MKLML_PROJECT})
LIST(APPEND external_project_dependencies mklml)
...@@ -7,10 +7,24 @@ set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK") ...@@ -7,10 +7,24 @@ set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK")
find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include) find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include)
find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib) find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib)
find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib) find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib)
find_library(NNPACK_UKERNELS_LIB NAMES nnpack_ukernels PATHS ${NNPACK_ROOT}/lib)
find_library(NNPACK_CPUFEATURES_LIB NAMES cpufeatures PATHS ${NNPACK_ROOT}/lib)
if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB) if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB)
set(NNPACK_FOUND ON) set(NNPACK_FOUND ON)
INCLUDE_DIRECTORIES(${NNPACK_INC_DIR}) INCLUDE_DIRECTORIES(${NNPACK_INC_DIR})
set(NNPACK_LIBS)
list(APPEND NNPACK_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB})
if (NNPACK_UKERNELS_LIB)
list(APPEND NNPACK_LIBS ${NNPACK_UKERNELS_LIB})
endif()
if (NNPACK_CPUFEATURES_LIB)
list(APPEND NNPACK_LIBS ${NNPACK_CPUFEATURES_LIB})
endif()
if(NOT ANDROID)
list(APPEND NNPACK_LIBS "rt")
endif()
else() else()
message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})") message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})")
endif() endif()
...@@ -69,9 +69,22 @@ ENDIF(NOT ${CBLAS_FOUND}) ...@@ -69,9 +69,22 @@ ENDIF(NOT ${CBLAS_FOUND})
MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}") MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}")
INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
ADD_LIBRARY(cblas STATIC IMPORTED) # FIXME(gangliao): generate cblas target to track all high performance
SET_PROPERTY(TARGET cblas PROPERTY IMPORTED_LOCATION ${CBLAS_LIBRARIES}) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
IF(${CBLAS_PROVIDER} MATCHES MKL)
ADD_LIBRARY(cblas SHARED ${dummyfile})
ELSE()
ADD_LIBRARY(cblas STATIC ${dummyfile})
ENDIF()
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
IF(NOT ${CBLAS_FOUND}) IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas) ADD_DEPENDENCIES(cblas extern_openblas)
LIST(APPEND external_project_dependencies cblas) LIST(APPEND external_project_dependencies cblas)
ELSE()
IF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
ADD_DEPENDENCIES(cblas mklml)
ENDIF()
ENDIF(NOT ${CBLAS_FOUND}) ENDIF(NOT ${CBLAS_FOUND})
...@@ -24,7 +24,6 @@ IF(WITH_PYTHON) ...@@ -24,7 +24,6 @@ IF(WITH_PYTHON)
ENDIF(WITH_PYTHON) ENDIF(WITH_PYTHON)
SET(py_env "") SET(py_env "")
SET(USE_VIRTUALENV_FOR_TEST 1)
IF(PYTHONINTERP_FOUND) IF(PYTHONINTERP_FOUND)
find_python_module(pip REQUIRED) find_python_module(pip REQUIRED)
find_python_module(numpy REQUIRED) find_python_module(numpy REQUIRED)
......
...@@ -110,7 +110,7 @@ set(COMMON_FLAGS ...@@ -110,7 +110,7 @@ set(COMMON_FLAGS
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=sign-compare -Wno-error=sign-compare
-Wno-error=unused-local-typedefs -Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in Pybind11 -Wno-error=parentheses-equality # Warnings in pybind11
) )
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
...@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS ...@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=unused-local-typedefs -Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
) )
if (APPLE) if (APPLE)
...@@ -189,6 +190,7 @@ endif() ...@@ -189,6 +190,7 @@ endif()
# Modern gpu architectures: Pascal # Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0") if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
list(APPEND CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
endif() endif()
# Custom gpu architecture # Custom gpu architecture
......
...@@ -104,6 +104,7 @@ function(merge_static_libs TARGET_NAME) ...@@ -104,6 +104,7 @@ function(merge_static_libs TARGET_NAME)
foreach(lib ${libs}) foreach(lib ${libs})
list(APPEND libs_deps ${${lib}_LIB_DEPENDS}) list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
endforeach() endforeach()
list(REMOVE_DUPLICATES libs_deps)
if(APPLE) # Use OSX's libtool to merge archives if(APPLE) # Use OSX's libtool to merge archives
# To produce a library we need at least one source file. # To produce a library we need at least one source file.
...@@ -127,7 +128,7 @@ function(merge_static_libs TARGET_NAME) ...@@ -127,7 +128,7 @@ function(merge_static_libs TARGET_NAME)
# Get the file names of the libraries to be merged # Get the file names of the libraries to be merged
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>) set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
endforeach() endforeach()
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles}) COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
else() # general UNIX: use "ar" to extract objects and re-add to a common lib else() # general UNIX: use "ar" to extract objects and re-add to a common lib
...@@ -145,11 +146,11 @@ function(merge_static_libs TARGET_NAME) ...@@ -145,11 +146,11 @@ function(merge_static_libs TARGET_NAME)
DEPENDS ${lib} ${objdir} DEPENDS ${lib} ${objdir}
WORKING_DIRECTORY ${objdir}) WORKING_DIRECTORY ${objdir})
# Empty dummy source file that goes into merged library # Empty dummy source file that goes into merged library
set(mergebase ${lib}.mergebase.c) set(mergebase ${lib}.mergebase.c)
add_custom_command(OUTPUT ${mergebase} add_custom_command(OUTPUT ${mergebase}
COMMAND ${CMAKE_COMMAND} -E touch ${mergebase} COMMAND ${CMAKE_COMMAND} -E touch ${mergebase}
DEPENDS ${objlistfile}) DEPENDS ${objlistfile})
list(APPEND mergebases "${mergebase}") list(APPEND mergebases "${mergebase}")
endforeach() endforeach()
...@@ -184,6 +185,16 @@ function(cc_library TARGET_NAME) ...@@ -184,6 +185,16 @@ function(cc_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
endif() endif()
# cpplint code style
foreach(source_file ${cc_library_SRCS})
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
endif()
endforeach()
add_style_check_target(${TARGET_NAME} ${cc_library_SRCS} ${cc_library_HEADERS})
else(cc_library_SRCS) else(cc_library_SRCS)
if (cc_library_DEPS) if (cc_library_DEPS)
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
...@@ -234,6 +245,14 @@ function(nv_library TARGET_NAME) ...@@ -234,6 +245,14 @@ function(nv_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
target_link_libraries(${TARGET_NAME} ${nv_library_DEPS}) target_link_libraries(${TARGET_NAME} ${nv_library_DEPS})
endif() endif()
# cpplint code style
foreach(source_file ${nv_library_SRCS})
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
endif()
endforeach()
add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS})
else(nv_library_SRCS) else(nv_library_SRCS)
if (nv_library_DEPS) if (nv_library_DEPS)
merge_static_libs(${TARGET_NAME} ${nv_library_DEPS}) merge_static_libs(${TARGET_NAME} ${nv_library_DEPS})
...@@ -285,8 +304,22 @@ function(go_library TARGET_NAME) ...@@ -285,8 +304,22 @@ function(go_library TARGET_NAME)
set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
endif() endif()
# Add dummy code to support `make target_name` under Terminal Command
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
# This custom command will always run since it depends on a not
# existing file.
add_custom_command(
OUTPUT dummy_rebulid_${TARGET_NAME}
COMMAND cmake -E touch ${dummyfile}
)
# Create a custom target that depends on the custom command output
# file, so the custom command can be referenced as a dependency by
# `add_dependencies`.
add_custom_target(rebuild_${TARGET_NAME}
DEPENDS dummy_rebulid_${TARGET_NAME}
)
# Add dummy code to support `make target_name` under Terminal Command
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
if (go_library_SHARED OR go_library_shared) if (go_library_SHARED OR go_library_shared)
add_library(${TARGET_NAME} SHARED ${dummyfile}) add_library(${TARGET_NAME} SHARED ${dummyfile})
...@@ -297,6 +330,12 @@ function(go_library TARGET_NAME) ...@@ -297,6 +330,12 @@ function(go_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${go_library_DEPS}) add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS) endif(go_library_DEPS)
# The "source file" of the library is `${dummyfile}` which never
# change, so the target will never rebuild. Make the target depends
# on the custom command that touches the library "source file", so
# rebuild will always happen.
add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME})
set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
...@@ -337,7 +376,7 @@ function(go_test TARGET_NAME) ...@@ -337,7 +376,7 @@ function(go_test TARGET_NAME)
string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test -race
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
".${CMAKE_CURRENT_SOURCE_REL_DIR}" ".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
...@@ -364,3 +403,16 @@ function(py_proto_compile TARGET_NAME) ...@@ -364,3 +403,16 @@ function(py_proto_compile TARGET_NAME)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS}) protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs}) add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction() endfunction()
function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
python2 ${py_test_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
...@@ -12,7 +12,7 @@ set(CPACK_PACKAGE_DESCRIPTION "") ...@@ -12,7 +12,7 @@ set(CPACK_PACKAGE_DESCRIPTION "")
set(CPACK_DEBIAN_PACKAGE_DEPENDS "libpython2.7-dev, libstdc++6, python-pip, curl, libgfortran3, python-pip-whl") set(CPACK_DEBIAN_PACKAGE_DEPENDS "libpython2.7-dev, libstdc++6, python-pip, curl, libgfortran3, python-pip-whl")
set(CPACK_DEBIAN_PACKAGE_SECTION Devel) set(CPACK_DEBIAN_PACKAGE_SECTION Devel)
set(CPACK_DEBIAN_PACKAGE_VERSION ${PADDLE_VERSION}) set(CPACK_DEBIAN_PACKAGE_VERSION ${PADDLE_VERSION})
set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${PROJ_ROOT}/paddle/scripts/deb/postinst") set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${PADDLE_SOURCE_DIR}/paddle/scripts/deb/postinst")
#set(CPACK_GENERATOR "DEB") #set(CPACK_GENERATOR "DEB")
# Start cpack # Start cpack
include (CMakePackageConfigHelpers) include (CMakePackageConfigHelpers)
......
...@@ -118,7 +118,6 @@ endfunction() ...@@ -118,7 +118,6 @@ endfunction()
macro(add_unittest_without_exec TARGET_NAME) macro(add_unittest_without_exec TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN}) add_executable(${TARGET_NAME} ${ARGN})
link_paddle_test(${TARGET_NAME}) link_paddle_test(${TARGET_NAME})
add_style_check_target(${TARGET_NAME} ${ARGN})
endmacro() endmacro()
# add_unittest # add_unittest
...@@ -142,17 +141,20 @@ endmacro() ...@@ -142,17 +141,20 @@ endmacro()
function(create_resources res_file output_file) function(create_resources res_file output_file)
add_custom_command( add_custom_command(
OUTPUT ${output_file} OUTPUT ${output_file}
COMMAND python ARGS ${PROJ_ROOT}/cmake/make_resource.py ${res_file} ${output_file} COMMAND python ARGS ${PADDLE_SOURCE_DIR}/cmake/make_resource.py ${res_file} ${output_file}
DEPENDS ${res_file} ${PROJ_ROOT}/cmake/make_resource.py) DEPENDS ${res_file} ${PADDLE_SOURCE_DIR}/cmake/make_resource.py)
endfunction() endfunction()
# Create a python unittest using run_python_tests.sh, # Create a python unittest using run_python_tests.sh,
# which takes care of making correct running environment # which takes care of making correct running environment
function(add_python_test TEST_NAME) function(add_python_test TEST_NAME)
add_test(NAME ${TEST_NAME} foreach(arg ${ARGN})
COMMAND env PADDLE_PACKAGE_DIR=${PADDLE_PYTHON_PACKAGE_DIR} get_filename_component(py_fn ${arg} NAME_WE)
bash ${PROJ_ROOT}/paddle/scripts/run_python_tests.sh set(TRG_NAME ${TEST_NAME}_${py_fn})
${USE_VIRTUALENV_FOR_TEST} ${PYTHON_EXECUTABLE} ${ARGN} add_test(NAME ${TRG_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) COMMAND env PYTHONPATH=${PADDLE_PYTHON_PACKAGE_DIR}
python2 ${arg}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endforeach()
endfunction() endfunction()
...@@ -4,7 +4,7 @@ set(tmp_version "HEAD") ...@@ -4,7 +4,7 @@ set(tmp_version "HEAD")
while ("${PADDLE_VERSION}" STREQUAL "") while ("${PADDLE_VERSION}" STREQUAL "")
execute_process( execute_process(
COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 ${tmp_version} COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 ${tmp_version}
WORKING_DIRECTORY ${PROJ_ROOT} WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_TAG_NAME OUTPUT_VARIABLE GIT_TAG_NAME
RESULT_VARIABLE GIT_RESULT RESULT_VARIABLE GIT_RESULT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
......
...@@ -104,6 +104,11 @@ cross_channel_norm ...@@ -104,6 +104,11 @@ cross_channel_norm
------------------ ------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm .. autoclass:: paddle.v2.layer.cross_channel_norm
:noindex: :noindex:
row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers Recurrent Layers
================ ================
...@@ -198,6 +203,10 @@ identity_projection ...@@ -198,6 +203,10 @@ identity_projection
.. autoclass:: paddle.v2.layer.identity_projection .. autoclass:: paddle.v2.layer.identity_projection
:noindex: :noindex:
slice_projection
-------------------
.. autoclass:: paddle.v2.layer.slice_projection
:noindex:
table_projection table_projection
---------------- ----------------
...@@ -248,6 +257,16 @@ seq_concat ...@@ -248,6 +257,16 @@ seq_concat
.. autoclass:: paddle.v2.layer.seq_concat .. autoclass:: paddle.v2.layer.seq_concat
:noindex: :noindex:
kmax_sequence_score
-------------------
.. autoclass:: paddle.v2.layer.kmax_sequence_score
:noindex:
sub_nested_seq
--------------
.. autoclass:: paddle.v2.layer.sub_nested_seq
:noindex:
Reshaping Layers Reshaping Layers
================ ================
...@@ -316,6 +335,11 @@ scaling ...@@ -316,6 +335,11 @@ scaling
.. autoclass:: paddle.v2.layer.scaling .. autoclass:: paddle.v2.layer.scaling
:noindex: :noindex:
clip
----
.. autoclass:: paddle.v2.layer.clip
:noindex:
slope_intercept slope_intercept
--------------- ---------------
.. autoclass:: paddle.v2.layer.slope_intercept .. autoclass:: paddle.v2.layer.slope_intercept
...@@ -338,6 +362,11 @@ trans ...@@ -338,6 +362,11 @@ trans
.. autoclass:: paddle.v2.layer.trans .. autoclass:: paddle.v2.layer.trans
:noindex: :noindex:
scale_shift
-----------
.. autoclass:: paddle.v2.layer.scale_shift
:noindex:
Sampling Layers Sampling Layers
=============== ===============
...@@ -474,6 +503,11 @@ prelu ...@@ -474,6 +503,11 @@ prelu
.. autoclass:: paddle.v2.layer.prelu .. autoclass:: paddle.v2.layer.prelu
:noindex: :noindex:
gated_unit
-----------
.. autoclass:: paddle.v2.layer.gated_unit
:noindex:
Detection output Layer Detection output Layer
====================== ======================
......
## Auto Gradient Checker Design
## Backgraound:
- Operator forward computing is easy to check if the result is right because it has a clear definition. **But** backpropagation is a notoriously difficult algorithm to debug and get right:
- 1. you should get the right backpropagation formula according to the forward computation.
- 2. you should implement it right in CPP.
- 3. it's difficult to prepare test data.
- Auto gradient check gets a numeric gradient by forward Operator and use it as a reference of the backward Operator's result. It has several advantages:
- 1. numeric gradient checker only need forward operator.
- 2. user only need to prepare the input data for forward Operator.
## Mathematical Theory
The following two document from stanford has a detailed explanation of how to get numeric gradient and why it's useful.
- [Gradient checking and advanced optimization(en)](http://deeplearning.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization)
- [Gradient checking and advanced optimization(cn)](http://ufldl.stanford.edu/wiki/index.php/%E6%A2%AF%E5%BA%A6%E6%A3%80%E9%AA%8C%E4%B8%8E%E9%AB%98%E7%BA%A7%E4%BC%98%E5%8C%96)
## Numeric Gradient Implementation
### Python Interface
```python
def get_numeric_gradient(op,
input_values,
output_name,
input_to_check,
delta=0.005,
local_scope=None):
"""
Get Numeric Gradient for an operator's input.
:param op: C++ operator instance, could be an network
:param input_values: The input variables. Should be an dictionary, key is
variable name. Value is numpy array.
:param output_name: The final output variable name.
:param input_to_check: The input variable need to get gradient.
:param delta: The perturbation value for numeric gradient method. The
smaller delta is, the more accurate result will get. But if that delta is
too small, it could occur numerical stability problem.
:param local_scope: The local scope used for get_numeric_gradient.
:return: The gradient array in numpy format.
"""
```
### Explaination:
- Why need `output_name`
- One Operator may have multiple Output, you can get independent gradient from each Output. So user should set one output to calculate.
- Why need `input_to_check`
- One operator may have multiple inputs. Gradient Op can calculate the gradient of these Inputs at the same time. But Numeric Gradient needs to calculate them one by one. So `get_numeric_gradient` is designed to calculate the gradient for one input. If you need to compute multiple inputs, you can call `get_numeric_gradient` multiple times.
### Core Algorithm Implementation
```python
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
# get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i)
# add delta to it, run op and then get the sum of the result tensor.
x_pos = origin + delta
tensor_to_check.set_float_element(i, x_pos)
y_pos = get_output()
# plus delta to this element, run op and get the sum of the result tensor.
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
# restore old value
tensor_to_check.set_float_element(i, origin)
# compute the gradient of this element and store it into a numpy array.
gradient_flat[i] = (y_pos - y_neg) / delta / 2
# reshape the gradient result to the shape of the source tensor.
return gradient_flat.reshape(tensor_to_check.get_dims())
```
## Auto Graident Checker Framework
Each Operator Kernel has three kinds of Gradient:
- 1. Numeric Gradient
- 2. CPU Operator Gradient
- 3. GPU Operator Gradient(if supported)
Numeric Gradient Only relies on forward Operator. So we use Numeric Gradient as the reference value.
- 1. calculate the numeric gradient.
- 2. calculate CPU kernel Gradient with the backward Operator and compare it with the numeric gradient.
- 3. calculate GPU kernel Gradient with the backward Operator and compare it with the numeric gradient.(if support GPU)
#### Python Interface
```python
def check_grad(self,
forward_op,
input_vars,
inputs_to_check,
output_name,
no_grad_set=None,
only_cpu=False,
max_relative_error=0.005):
"""
:param forward_op: used to create backward_op
:param input_vars: numpy value of input variable. The following
computation will use these variables.
:param inputs_to_check: inputs var names that should check gradient.
:param output_name: output name that used to
:param max_relative_error: The relative tolerance parameter.
:param no_grad_set: used when create backward ops
:param only_cpu: only compute and check gradient on cpu kernel.
:return:
"""
```
### How to check if two numpy array is close enough?
if `abs_numeric_grad` is nearly zero, then use abs error for numeric_grad, not relative
```python
numeric_grad = ...
operator_grad = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
abs_numeric_grad = numpy.abs(numeric_grad)
# if abs_numeric_grad is nearly zero, then use abs error for numeric_grad, not relative
# error.
abs_numeric_grad[abs_numeric_grad < 1e-3] = 1
diff_mat = numpy.abs(abs_numeric_grad - operator_grad) / abs_numeric_grad
max_diff = numpy.max(diff_mat)
```
#### Notes:
1,The Input data for auto gradient checker should be reasonable to avoid numeric problem.
#### Refs:
- [Gradient checking and advanced optimization(en)](http://deeplearning.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization)
- [Gradient checking and advanced optimization(cn)](http://ufldl.stanford.edu/wiki/index.php/%E6%A2%AF%E5%BA%A6%E6%A3%80%E9%AA%8C%E4%B8%8E%E9%AB%98%E7%BA%A7%E4%BC%98%E5%8C%96)
# Alalysis of large model distributed training in Paddle
***NOTE: This is only some note for how we implemeted this scheme in V1, not a new design.***
## What is it
We often encounter cases that the embedding layer parameters(sparse) are so large that we can not store it in the trainer's memory when training. So we need to put them to several servers, and fetch them row by row instead of fetch all of the parameters.
## How to use
Specify command-line argument like `--loadsave_parameters_in_pserver=true --ports_num_for_sparse=1 --use_old_updater=1` when starting the paddle trainer. And also add something like `--ports_num_for_sparse=1 --pserver_num_threads=5` when starting pserver processes.
Accrodingly, configure your embedding layers like:
```python
SPARSE_REMOTE=True
w1 = data_layer(name="w1", size=dict_size)
emb1 = embedding_layer(input=w1, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
w2 = data_layer(name="w2", size=dict_size)
emb2 = embedding_layer(input=w2, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
...
```
## Implementation details
```c++
enum MatType {
MAT_NORMAL,
MAT_NORMAL_SHARED,
MAT_VALUE_SHARED,
MAT_SPARSE_ROW_IDS,
MAT_SPARSE_ROW_AUTO_GROW,
MAT_CACHE_ROW,
MAT_SPARSE_ROW,
MAT_SPARSE_ROW_PREFETCH,
MAT_SPARSE_ROW_PREFETCH_FULL_SIZE,
};
```
`MAT_SPARSE_ROW_PREFETCH` is what we use when configured to fetch only row of matrix when training.
In `trainer_internal.cpp:L93 trainOneBatch`:
```c++
if (config_->getOptConfig().use_sparse_remote_updater()) {
REGISTER_TIMER("prefetch");
gradientMachine_->prefetch(inArgs);
parameterUpdater_->getParametersRemote();
}
```
When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver.
In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
```c++
if (fullSize) {
...
} else {
getParams = [&] {
parameterClient_->getParameterSparse(
/* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType);
};
applyL1 = [](Parameter& para, real decayRate) {
para.getMat(PARAMETER_VALUE)->applyL1(/*lr=*/1.0f, decayRate);
};
}
```
Calling `parameterClient_->getParameterSparse` will do remote call to pserver's `getParameterSparse`:
```c++
void ParameterServer2::getParameterSparse(const SendParameterRequest& request,
std::vector<Buffer>& inputBuffers,
SendParameterResponse* response,
std::vector<Buffer>* outputBuffers) {
(void)inputBuffers;
auto& buffer = *readWriteBuffer_;
size_t numReals = 0;
for (const auto& block : request.blocks()) {
numReals += getParameterConfig(block).dims(1);
}
buffer.resize(numReals);
VLOG(3) << "pserver: getParameterSparse, numReals=" << numReals;
ReadLockGuard guard(parameterMutex_);
size_t offset = 0;
for (const auto& block : request.blocks()) {
size_t width = getParameterConfig(block).dims(1);
Buffer buf = {buffer.data() + offset, width};
int type = request.send_back_parameter_type();
sendBackParameterSparse(block, type, response, &buf, width, outputBuffers);
offset += width;
}
}
```
`getParameterConfig(block).dims(1)` returns the width of the current "parameter block"(a shard of parameter object),
then `getParameterSparse` remote call returns only one row of data to the client.
...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future. ...@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election ### Trainer Election
One trainer will be elected as the one to save the model. When using One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to etcd, trainer ID is a randomly generated UUID, the trainer will
elect one trainer. When not using etcd, unique trainer IDs will be contact the master server requesting to save the model, and find out
given by the administrator, the trainer whose ID is "0" is elected to if itself is elected. When the master server is not used, unique
save the model. trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path ### Model Save Path
......
# Intel® MKL-DNN on PaddlePaddle: Design Doc
我们计划将Intel深度神经网络数学库(**MKL-DNN**\[[1](#references)\])集成到PaddlePaddle,充分展现英特尔平台的优势,有效提升PaddlePaddle在英特尔架构上的性能。
我们短期内的基本目标是:
- 完成常用layer的MKL-DNN实现。
- 完成常见深度神经网络VGG,GoogLeNet 和 ResNet的MKL-DNN实现。
## Contents
- [Overview](#overview)
- [Actions](#actions)
- [CMake](#cmake)
- [Layers](#layers)
- [Activations](#activations)
- [Unit Tests](#unit-tests)
- [Protobuf Messages](#protobuf-messages)
- [Python API](#python-api)
- [Demos](#demos)
- [Benchmarking](#benchmarking)
- [Others](#others)
- [Design Concerns](#design-concerns)
## Overview
我们会把MKL-DNN作为第三方库集成进PaddlePaddle,整体框架图
<div align="center">
<img src="image/overview.png" width=350><br/>
Figure 1. PaddlePaddle on IA.
</div>
## Actions
我们把集成方案大致分为了如下几个方面。
### CMake
我们会在`CMakeLists.txt`中会添加`WITH_MKLDNN`的选项,当设置这个值为`ON`的时候会启用编译MKL-DNN功能。同时会自动开启OpenMP用于提高MKL-DNN的性能。
同时,我们会引入`WITH_MKLML`选项,用于选择是否使用MKL-DNN自带的MKLML安装包。这个安装包可以独立于MKL-DNN使用,但是建议在开启MKL-DNN的同时也打开MKLML的开关,这样才能发挥最好的性能。
所以,我们会在`cmake/external`目录新建`mkldnn.cmake``mklml.cmake`文件,它们会在编译PaddlePaddle的时候下载对应的软件包,并放到PaddlePaddle的third party目录中。
**备注**:当`WITH_MKLML=ON`的时候,会优先使用这个包作为PaddlePaddle的CBLAS和LAPACK库,所以会稍微改动`cmake/cblas.cmake`中的逻辑。
### Layers
所有MKL-DNN相关的C++ layers,都会按照PaddlePaddle的目录结构存放在
`paddle/gserver/layers`中,并且文件名都会一以*Mkldnn*开头。
所有MKL-DNN的layers都会继承于一个叫做`MkldnnLayer`的父类,该父类继承于PaddlePaddle的基类`Layer`
### Activations
由于在PaddlePaddle中,激活函数是独立于layer概念的,所以会在`paddle/gserver/activations`目录下添加一个`MkldnnActivation.h`文件定义一些用于MKL-DNN的接口,实现方法还是会在`ActivationFunction.cpp`文件。
### Unit Tests
会在`paddle/gserver/test`目录下添加`test_Mkldnn.cpp``MkldnnTester.*`用于MKL-DNN的测试。
Activation的测试,计划在PaddlePaddle原有的测试文件上直接添加新的测试type。
### Protobuf Messages
根据具体layer的需求可能会在`proto/ModelConfig.proto`里面添加必要的选项。
### Python API
目前只考虑**v1 API**
计划在`python/paddle/trainer/config_parser.py`里面添加`use_mkldnn`这个选择,方便用户选择使用MKL-DNN的layers。
具体实现方式比如:
```python
use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
if use_mkldnn
self.layer_type = mkldnn_*
```
所有MKL-DNN的layer type会以*mkldnn_*开头,以示区分。
并且可能在`python/paddle/trainer_config_helper`目录下的`activations.py ``layers.py`里面添加必要的MKL-DNN的接口。
### Demos
会在`v1_api_demo`目录下添加一个`mkldnn`的文件夹,里面放入一些用于MKL-DNN测试的demo脚本。
### Benchmarking
会考虑添加部分逻辑在`benchmark/paddle/image/run.sh`,添加使用MKL-DNN的测试。
### Others
1. 如果在使用MKL-DNN的情况下,会把CPU的Buffer对齐为64。
2. 深入PaddlePaddle,寻找有没有其他可以优化的可能,进一步优化。比如可能会用OpenMP改进SGD的更新性能。
## Design Concerns
为了更好的符合PaddlePaddle的代码风格\[[2](#references)\],同时又尽可能少的牺牲MKL-DNN的性能\[[3](#references)\]
我们总结出一些特别需要注意的点:
1. 使用**deviceId_**。为了尽可能少的在父类Layer中添加变量或者函数,我们决定使用已有的`deviceId_`变量来区分layer的属性,定义`-2``MkldnnLayer`特有的设备ID。
2. 重写父类Layer的**init**函数,修改`deviceId_``-2`,代表这个layer是用于跑在MKL-DNN的环境下。
3. 创建`MkldnnMatrix`,用于管理MKL-DNN会用到的相关memory函数、接口以及会用的到格式信息。
4. 创建`MkldnnBase`,定义一些除了layer和memory相关的类和函数。包括MKL-DNN会用到`MkldnnStream``CpuEngine`,和未来可能还会用到`FPGAEngine`等。
5.**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue``mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。
6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。
7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。
8. 关于MKLDNN参数的保存。由于MKLDNN参数的格式与PaddlePaddle原有的格式存在不一样的情况,所以需要在保存参数时同时保存该格式信息。目前准备扩展[Header](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/Parameter.h#L247)里面的`int32_t version`。这个值不管是在v1还是在v2里面,一直保存的是0,所以可以充分利用这个信息,定义一个枚举处理所有MKLDNN的参数格式,从而`MKLDNNLayer`就可以从输入的参数中获取需要的格式信息。
## References
1. [Intel Math Kernel Library for Deep Neural Networks (Intel MKL-DNN)](https://github.com/01org/mkl-dnn "Intel MKL-DNN")
2. [原来的方案](https://github.com/PaddlePaddle/Paddle/pull/3096)会引入**nextLayer**的信息。但是在PaddlePaddle中,无论是重构前的layer还是重构后的op,都不会想要知道next layer/op的信息。
3. MKL-DNN的高性能格式与PaddlePaddle原有的`NCHW`不同(PaddlePaddle中的CUDNN部分使用的也是`NCHW`,所以不存在这个问题),所以需要引入一个转换方法,并且只需要在必要的时候转换这种格式,才能更好的发挥MKL-DNN的性能。
...@@ -11,6 +11,15 @@ Paddle每次发新的版本,遵循以下流程: ...@@ -11,6 +11,15 @@ Paddle每次发新的版本,遵循以下流程:
* 编译这个版本的Ubuntu Deb包。如果失败,修复Ubuntu Deb包编译问题,Patch号加一,返回第二步。 * 编译这个版本的Ubuntu Deb包。如果失败,修复Ubuntu Deb包编译问题,Patch号加一,返回第二步。
* 使用Regression Test List作为检查列表,测试Docker镜像/ubuntu安装包的功能正确性 * 使用Regression Test List作为检查列表,测试Docker镜像/ubuntu安装包的功能正确性
* 如果失败,记录下所有失败的例子,在这个`release/版本号`分支中,修复所有bug后,Patch号加一,返回第二步 * 如果失败,记录下所有失败的例子,在这个`release/版本号`分支中,修复所有bug后,Patch号加一,返回第二步
* 编译这个版本的python wheel包,并发布到pypi。
* 由于pypi.python.org目前遵循[严格的命名规范PEP 513](https://www.python.org/dev/peps/pep-0513),在使用twine上传之前,需要重命名wheel包中platform相关的后缀,比如将`linux_x86_64`修改成`manylinux1_x86_64`
* pypi上的package名称为paddlepaddle和paddlepaddle_gpu,如果要上传GPU版本的包,需要修改build/python/setup.py中,name: "paddlepaddle_gpu"并重新打包wheel包:`python setup.py bdist_wheel`
* 上传方法:
```
cd build/python
pip install twine
twine upload dist/[package to upload]
```
4. 第三步完成后,将`release/版本号`分支合入master分支,并删除`release/版本号`分支。将master分支的合入commit打上tag,tag为`版本号`。同时再将`master`分支合入`develop`分支。最后删除`release/版本号`分支。 4. 第三步完成后,将`release/版本号`分支合入master分支,并删除`release/版本号`分支。将master分支的合入commit打上tag,tag为`版本号`。同时再将`master`分支合入`develop`分支。最后删除`release/版本号`分支。
5. 编译master分支的Docker发行镜像,发布到dockerhub。编译ubuntu的deb包,发布到github release页面 5. 编译master分支的Docker发行镜像,发布到dockerhub。编译ubuntu的deb包,发布到github release页面
6. 协同完成Release Note的书写 6. 协同完成Release Note的书写
......
...@@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`. ...@@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`.
```cpp ```cpp
class Scope { class Scope {
public: public:
Variable* CreateVariable(const std::string& name); Variable* NewVar(const std::string& name);
const Variable* GetVariable(const std::string& name) const; const Variable* FindVar(const std::string& name) const;
private: private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
...@@ -58,12 +58,12 @@ class Scope { ...@@ -58,12 +58,12 @@ class Scope {
public: public:
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {} Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const { Variable* FindVar(const std::string& name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
return it->second.get(); return it->second.get();
} else if (parent_ != nullptr) { } else if (parent_ != nullptr) {
return parent_->GetVariable(name); return parent_->FindVar(name);
} else { } else {
return nullptr; return nullptr;
} }
...@@ -95,10 +95,10 @@ class Scope { ...@@ -95,10 +95,10 @@ class Scope {
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr); static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
// return nullptr if not found. // return nullptr if not found.
Variable* GetVariable(const std::string& name) const; Variable* FindVar(const std::string& name) const;
// return if already contains same name variable. // return if already contains same name variable.
Variable* CreateVariable(const std::string& name); Variable* NewVar(const std::string& name);
private: private:
std::shared_ptr<Scope> parent_; std::shared_ptr<Scope> parent_;
...@@ -107,11 +107,11 @@ class Scope { ...@@ -107,11 +107,11 @@ class Scope {
``` ```
## Only scope can create a variable ## Only scope can create a variable
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`. To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `NewVar` can construct `Variable`.
## When scope destroyed, all variables inside this scope should be destroyed together ## When scope destroyed, all variables inside this scope should be destroyed together
The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together. The scope hold unique pointers for all variables. User can `FindVar` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
## Sharing a parent scope ## Sharing a parent scope
...@@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar ...@@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar
## Orthogonal interface ## Orthogonal interface
`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily. `FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `NewVar` will return a `Error` when there is a name conflict locally. Combine `FindVar` and `NewVar`, we can implement `NewVar` easily.
...@@ -49,6 +49,7 @@ message AttrProto { ...@@ -49,6 +49,7 @@ message AttrProto {
message VarProto { message VarProto {
required string name = 1; required string name = 1;
required string comment = 2; required string comment = 2;
required bool is_tensor = 3;
}; };
message OpProto { message OpProto {
......
...@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异 ...@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。 * 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习律或者对数据进行归一化处理。 主要的解决办法是减小学习律或者对数据进行归一化处理。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
pip uninstall py_paddle paddle
然后安装paddle的python环境, 在build目录下执行
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
...@@ -68,7 +68,7 @@ As a simple example, consider the following: ...@@ -68,7 +68,7 @@ As a simple example, consider the following:
1. **BLAS Dependencies(optional)** 1. **BLAS Dependencies(optional)**
CMake will search BLAS libraries from system. If not found, OpenBLAS will be downloaded, built and installed automatically. CMake will search BLAS libraries from the system. If not found, OpenBLAS will be downloaded, built and installed automatically.
To utilize preinstalled BLAS, you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`. To utilize preinstalled BLAS, you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`.
```bash ```bash
...@@ -131,9 +131,9 @@ As a simple example, consider the following: ...@@ -131,9 +131,9 @@ As a simple example, consider the following:
To build GPU version, you will need the following installed: To build GPU version, you will need the following installed:
1. a CUDA-capable GPU 1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain 2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment, The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on including the host compiler and C runtime libraries, and is therefore only supported on
...@@ -172,6 +172,7 @@ export PATH=<path to install>/bin:$PATH ...@@ -172,6 +172,7 @@ export PATH=<path to install>/bin:$PATH
# install PaddlePaddle Python modules. # install PaddlePaddle Python modules.
sudo pip install <path to install>/opt/paddle/share/wheels/*.whl sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
``` ```
## <span id="centos">Build on Centos 7</span> ## <span id="centos">Build on Centos 7</span>
### Install Dependencies ### Install Dependencies
...@@ -192,9 +193,9 @@ sudo pip install <path to install>/opt/paddle/share/wheels/*.whl ...@@ -192,9 +193,9 @@ sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
To build GPU version, you will need the following installed: To build GPU version, you will need the following installed:
1. a CUDA-capable GPU 1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain 2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment, The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on including the host compiler and C runtime libraries, and is therefore only supported on
...@@ -222,7 +223,7 @@ mkdir build && cd build ...@@ -222,7 +223,7 @@ mkdir build && cd build
``` ```
Finally, you can build and install PaddlePaddle: Finally, you can build and install PaddlePaddle:
```bash ```bash
# you can add build option here, such as: # you can add build option here, such as:
cmake3 .. -DCMAKE_INSTALL_PREFIX=<path to install> cmake3 .. -DCMAKE_INSTALL_PREFIX=<path to install>
......
...@@ -3,6 +3,43 @@ PaddlePaddle的Docker容器使用方式 ...@@ -3,6 +3,43 @@ PaddlePaddle的Docker容器使用方式
PaddlePaddle目前唯一官方支持的运行的方式是Docker容器。因为Docker能在所有主要操作系统(包括Linux,Mac OS X和Windows)上运行。 请注意,您需要更改 `Dockers设置 <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 才能充分利用Mac OS X和Windows上的硬件资源。 PaddlePaddle目前唯一官方支持的运行的方式是Docker容器。因为Docker能在所有主要操作系统(包括Linux,Mac OS X和Windows)上运行。 请注意,您需要更改 `Dockers设置 <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 才能充分利用Mac OS X和Windows上的硬件资源。
Docker使用入门
------------------------------
几个基础的概念帮助理解和使用Docker:
- *镜像*:一个Docker镜像是一个打包好的软件。它包含了这个软件本身和它所依赖的运行环境。PaddlePaddle的Docker镜像就包含了PaddlePaddle的Python库以及其依赖的多个Python库。这样我们可以直接在Docker中运行需要的程序而不需要安装后在执行。可以执行:
.. code-block:: bash
docker images
来列出当前系统中的所有镜像,同样可以执行:
.. code-block:: bash
docker pull paddlepaddle/paddle:0.10.0
来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。
- *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。
实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。
可以执行:
.. code-block:: bash
docker run paddlepaddle/paddle:0.10.0
来使用一个镜像启动一个容器。
- 默认情况下,Docker容器会运行在独立的文件系统空间之上,我们无法在Docker容器中
访问到主机上的文件。可以通过*挂载Volume*的方式,将主机上的文件或目录挂载到
Docker容器中。下面的命令把当前目录挂载到了容器中的 /data 目录下,容器使用
debian镜像,并且启动后执行 :code:`ls /data`。
.. code-block:: bash
docker run --rm -v $(pwd):/data debian ls /data
PaddlePaddle发布的Docker镜像使用说明 PaddlePaddle发布的Docker镜像使用说明
------------------------------ ------------------------------
...@@ -12,11 +49,11 @@ PaddlePaddle需要的所有编译工具。把编译出来的PaddlePaddle也打 ...@@ -12,11 +49,11 @@ PaddlePaddle需要的所有编译工具。把编译出来的PaddlePaddle也打
像,称为生产镜像,里面涵盖了PaddlePaddle运行所需的所有环境。每次 像,称为生产镜像,里面涵盖了PaddlePaddle运行所需的所有环境。每次
PaddlePaddle发布新版本的时候都会发布对应版本的生产镜像以及开发镜像。运 PaddlePaddle发布新版本的时候都会发布对应版本的生产镜像以及开发镜像。运
行镜像包括纯CPU版本和GPU版本以及其对应的非AVX版本。我们会在 行镜像包括纯CPU版本和GPU版本以及其对应的非AVX版本。我们会在
`dockerhub.com <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_ 提供最新 `dockerhub.com <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_
的Docker镜像,可以在"tags"标签下找到最新的Paddle镜像版本。为了方便在国 和国内镜像`docker.paddlepaddle.org` 提供最新
内的开发者下载Docker镜像,我们提供了国内的镜像服务器供大家使用。如果您 的Docker镜像,可以在"tags"标签下找到最新的Paddle镜像版本。
在国内,请把文档里命令中的paddlepaddle/paddle替换成
docker.paddlepaddle.org/paddle。 **注意:为了方便在国内的开发者下载Docker镜像,我们提供了国内的镜像服务器供大家使用。如果您在国内,请把文档里命令中的paddlepaddle/paddle替换成docker.paddlepaddle.org/paddle。**
1. 开发镜像::code:`paddlepaddle/paddle:0.10.0-dev` 1. 开发镜像::code:`paddlepaddle/paddle:0.10.0-dev`
...@@ -37,13 +74,13 @@ docker.paddlepaddle.org/paddle。 ...@@ -37,13 +74,13 @@ docker.paddlepaddle.org/paddle。
.. code-block:: bash .. code-block:: bash
docker run -it --rm paddlepaddle/paddle:0.10.0-dev /bin/bash docker run -it --rm -v $(pwd):/paddle paddlepaddle/paddle:0.10.0-dev /bin/bash
或者,可以以后台进程方式运行容器: 或者,可以以后台进程方式运行容器:
.. code-block:: bash .. code-block:: bash
docker run -d -p 2202:22 -p 8888:8888 paddledev/paddle:0.10.0-dev docker run -d -p 2202:22 -p 8888:8888 -v $(pwd):/paddle paddlepaddle/paddle:0.10.0-dev /usr/sbin/sshd -D
然后用密码 :code:`root` SSH进入容器: 然后用密码 :code:`root` SSH进入容器:
...@@ -68,6 +105,8 @@ docker.paddlepaddle.org/paddle。 ...@@ -68,6 +105,8 @@ docker.paddlepaddle.org/paddle。
如果输出是No,就需要选择使用no-AVX的镜像 如果输出是No,就需要选择使用no-AVX的镜像
**注:在0.10.0之后的版本,PaddlePaddle都可以自动判断硬件是否支持AVX,所以无需判断AVX即可使用**
以上方法在GPU镜像里也能用,只是请不要忘记提前在物理机上安装GPU最新驱动。 以上方法在GPU镜像里也能用,只是请不要忘记提前在物理机上安装GPU最新驱动。
为了保证GPU驱动能够在镜像里面正常运行,我们推荐使用[nvidia-docker](https://github.com/NVIDIA/nvidia-docker)来运行镜像。 为了保证GPU驱动能够在镜像里面正常运行,我们推荐使用[nvidia-docker](https://github.com/NVIDIA/nvidia-docker)来运行镜像。
......
...@@ -63,12 +63,35 @@ CPU-only version and a CUDA GPU version and their no-AVX versions. ...@@ -63,12 +63,35 @@ CPU-only version and a CUDA GPU version and their no-AVX versions.
We put the docker images on `dockerhub.com We put the docker images on `dockerhub.com
<https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_. You can find the <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_. You can find the
latest versions under "tags" tab at dockerhub.com. If you are in latest versions under "tags" tab at dockerhub.com.
China, you can use our Docker image registry mirror to speed up the
download process. To use it, please replace all paddlepaddle/paddle in
the commands to docker.paddlepaddle.org/paddle.
1. Production images, this image might have multiple variants: ** NOTE: If you are in China, you can use our Docker image registry mirror to speed up the download process. To use it, please replace all paddlepaddle/paddle in the commands to docker.paddlepaddle.org/paddle.**
1. development image :code:`paddlepaddle/paddle:<version>-dev`
This image has packed related develop tools and runtime
environment. Users and developers can use this image instead of
their own local computer to accomplish development, build,
releasing, document writing etc. While different version of paddle
may depends on different version of libraries and tools, if you
want to setup a local environment, you must pay attention to the
versions. The development image contains:
- gcc/clang
- nvcc
- Python
- sphinx
- woboq
- sshd
Many developers use servers with GPUs, they can use ssh to login to
the server and run :code:`docker exec` to enter the docker
container and start their work. Also they can start a development
docker image with SSHD service, so they can login to the container
and start work.
2. Production images, this image might have multiple variants:
- GPU/AVX::code:`paddlepaddle/paddle:<version>-gpu` - GPU/AVX::code:`paddlepaddle/paddle:<version>-gpu`
- GPU/no-AVX::code:`paddlepaddle/paddle:<version>-gpu-noavx` - GPU/no-AVX::code:`paddlepaddle/paddle:<version>-gpu-noavx`
...@@ -84,7 +107,7 @@ the commands to docker.paddlepaddle.org/paddle. ...@@ -84,7 +107,7 @@ the commands to docker.paddlepaddle.org/paddle.
if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi
**NOTE:versions after 0.10.0 will automatically detect system AVX support, so manual detect is not needed in this case.**
To run the CPU-only image as an interactive container: To run the CPU-only image as an interactive container:
.. code-block:: bash .. code-block:: bash
...@@ -103,29 +126,6 @@ the commands to docker.paddlepaddle.org/paddle. ...@@ -103,29 +126,6 @@ the commands to docker.paddlepaddle.org/paddle.
nvidia-docker run -it --rm paddlepaddle/paddle:0.10.0-gpu /bin/bash nvidia-docker run -it --rm paddlepaddle/paddle:0.10.0-gpu /bin/bash
2. development image :code:`paddlepaddle/paddle:<version>-dev`
This image has packed related develop tools and runtime
environment. Users and developers can use this image instead of
their own local computer to accomplish development, build,
releasing, document writing etc. While different version of paddle
may depends on different version of libraries and tools, if you
want to setup a local environment, you must pay attention to the
versions. The development image contains:
- gcc/clang
- nvcc
- Python
- sphinx
- woboq
- sshd
Many developers use servers with GPUs, they can use ssh to login to
the server and run :code:`docker exec` to enter the docker
container and start their work. Also they can start a development
docker image with SSHD service, so they can login to the container
and start work.
Train Model Using Python API Train Model Using Python API
---------------------------- ----------------------------
......
...@@ -13,22 +13,18 @@ ...@@ -13,22 +13,18 @@
# serve to show the default. # serve to show the default.
import sys import sys
import os, subprocess import os, subprocess
sys.path.insert(0, os.path.abspath('@PADDLE_SOURCE_DIR@/python'))
import shlex import shlex
from recommonmark import parser, transform from recommonmark import parser, transform
try: import paddle
import py_paddle import paddle.v2
import paddle
import paddle.v2
except ImportError:
print("Must install paddle python package before generating documentation")
sys.exit(1)
MarkdownParser = parser.CommonMarkParser MarkdownParser = parser.CommonMarkParser
AutoStructify = transform.AutoStructify AutoStructify = transform.AutoStructify
# If extensions (or modules to document with autodoc) are in another directory, # If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the # add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here. # documentation root, use os.path.abspath to make it absolute, like shown here.
templates_path = ["@PROJ_ROOT@/doc_theme/templates"] templates_path = ["@PADDLE_SOURCE_DIR@/doc_theme/templates"]
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
...@@ -124,7 +120,7 @@ html_theme = 'sphinx_rtd_theme' ...@@ -124,7 +120,7 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['@PROJ_ROOT@/doc_theme/static'] html_static_path = ['@PADDLE_SOURCE_DIR@/doc_theme/static']
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc' htmlhelp_basename = project + 'doc'
......
...@@ -13,15 +13,11 @@ ...@@ -13,15 +13,11 @@
# serve to show the default. # serve to show the default.
import sys import sys
import os, subprocess import os, subprocess
sys.path.insert(0, os.path.abspath('@PADDLE_SOURCE_DIR@/python'))
import shlex import shlex
from recommonmark import parser, transform from recommonmark import parser, transform
try: import paddle
import py_paddle import paddle.v2
import paddle
import paddle.v2
except ImportError:
print("Must install paddle python package before generating documentation")
sys.exit(1)
MarkdownParser = parser.CommonMarkParser MarkdownParser = parser.CommonMarkParser
...@@ -29,7 +25,7 @@ AutoStructify = transform.AutoStructify ...@@ -29,7 +25,7 @@ AutoStructify = transform.AutoStructify
# If extensions (or modules to document with autodoc) are in another directory, # If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the # add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here. # documentation root, use os.path.abspath to make it absolute, like shown here.
templates_path = ["@PROJ_ROOT@/doc_theme/templates"] templates_path = ["@PADDLE_SOURCE_DIR@/doc_theme/templates"]
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
...@@ -124,7 +120,7 @@ html_theme = 'sphinx_rtd_theme' ...@@ -124,7 +120,7 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['@PROJ_ROOT@/doc_theme/static'] html_static_path = ['@PADDLE_SOURCE_DIR@/doc_theme/static']
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc' htmlhelp_basename = project + 'doc'
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
import ( import (
...@@ -5,12 +19,15 @@ import ( ...@@ -5,12 +19,15 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
...@@ -20,11 +37,18 @@ func main() { ...@@ -20,11 +37,18 @@ func main() {
port := flag.Int("port", 8080, "port of the master server.") port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
level, e := log.ParseLevel(*logLevel)
candy.Must(e)
log.SetLevel(level)
if *endpoints == "" { if *endpoints == "" {
log.Warningln("-endpoints not set, fault tolerance not be enabled.") log.Warningln("-endpoints not set, fault tolerance not be enabled.")
} }
...@@ -46,6 +70,20 @@ func main() { ...@@ -46,6 +70,20 @@ func main() {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() {
log.Infoln("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Errorln(err)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
...@@ -62,8 +100,12 @@ func main() { ...@@ -62,8 +100,12 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
err = http.Serve(l, nil) go func() {
if err != nil { err = http.Serve(l, nil)
log.Fatal(err) if err != nil {
} log.Fatal(err)
}
}()
<-c
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
import ( import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"time" "time"
...@@ -16,10 +32,11 @@ import ( ...@@ -16,10 +32,11 @@ import (
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
...@@ -39,16 +56,34 @@ func main() { ...@@ -39,16 +56,34 @@ func main() {
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
idx, err = e.Register() idx, err = e.Register(*port)
candy.Must(err) candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil { if err != nil {
log.Errorf("Fetch checkpoint failed, %s", err) if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.")
} else {
panic(err)
}
}
}
shutdown := func() {
log.Infoln("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
log.Errorln(sErr)
} }
} }
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err) candy.Must(err)
...@@ -59,7 +94,11 @@ func main() { ...@@ -59,7 +94,11 @@ func main() {
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
candy.Must(err) candy.Must(err)
log.Infof("start pserver at port %d", *port) go func() {
err = http.Serve(l, nil) log.Infof("start pserver at port %d", *port)
candy.Must(err) err = http.Serve(l, nil)
candy.Must(err)
}()
<-c
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection package connection
import ( import (
......
hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855 hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-07-11T10:04:40.786745417+08:00 updated: 2017-08-07T23:37:48.867469328Z
imports: imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: cb2a496c4ddd1c87a9f280e116649b599999ec79 version: d0d1a87aa96ae14914751d42264262cb69eda170
subpackages: subpackages:
- alarm
- auth
- auth/authpb - auth/authpb
- client
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes - etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb - etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb - mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf - name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7 version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages: subpackages:
...@@ -17,14 +108,63 @@ imports: ...@@ -17,14 +108,63 @@ imports:
- proto - proto
- name: github.com/golang/snappy - name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9 version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag - name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04 version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1 version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy - name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 1351f936d976c60a0a48d728281922cf63eafb8d
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
...@@ -36,11 +176,15 @@ imports: ...@@ -36,11 +176,15 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: abf9c25f54453410d0c6668e519582a9e1115027 version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
repo: https://github.com/golang/sys.git
vcs: git
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform
...@@ -60,4 +204,18 @@ imports: ...@@ -60,4 +204,18 @@ imports:
- stats - stats
- tap - tap
- transport - transport
testImports: [] - name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
...@@ -6,8 +6,21 @@ import: ...@@ -6,8 +6,21 @@ import:
subpackages: subpackages:
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag - package: github.com/namsral/flag
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy - package: github.com/topicai/candy
- package: golang.org/x/crypto
repo: https://github.com/golang/crypto.git
vcs: git
- package: golang.org/x/sys
repo: https://github.com/golang/sys.git
vcs: git
- package: golang.org/x/text
repo: https://github.com/golang/text.git
vcs: git
- package: github.com/satori/go.uuid
version: v1.1.0
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(master_test) go_test(master_test)
endif() endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer) go_library(paddle_master SHARED DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
/* /*
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#define PADDLE_MASTER_OK 0 #define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1 #define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client; typedef int paddle_master_client;
*/ */
import "C" import "C"
...@@ -19,11 +35,9 @@ import ( ...@@ -19,11 +35,9 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client) var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client var curHandle C.paddle_master_client
...@@ -52,32 +66,32 @@ func remove(client C.paddle_master_client) *master.Client { ...@@ -52,32 +66,32 @@ func remove(client C.paddle_master_client) *master.Client {
} }
//export paddle_new_etcd_master_client //export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client { func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints) p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{ endpoints := strings.Split(p, ",")
Endpoints: strings.Split(p, ","), c, err := master.NewClient(
DialTimeout: time.Second * time.Duration(timeout), master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
}) master.WithBuffer(bufSize),
if err != nil { )
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil { if err != nil {
panic(err) panic(err)
} }
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c) return add(c)
} }
//export paddle_new_master_client //export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr) a := C.GoString(addr)
ch := make(chan string, 1) c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
ch <- a if err != nil {
c := master.NewClient(ch, bufSize) panic(err)
}
return add(c) return add(c)
} }
...@@ -86,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) { ...@@ -86,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client) remove(client)
} }
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset //export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client) c := get(client)
...@@ -104,23 +124,28 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -104,23 +124,28 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK return C.PADDLE_MASTER_OK
} }
// return value: // paddle_next_record gets the nexts training record.
// 0:ok //
// -1:error // returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r, err := c.NextRecord() r, err := c.NextRecord()
if err != nil { if err != nil {
// Error // NOTE: use errors to indicate pass ends
// TODO: return the type of error? if err.Error() == master.ErrAllTaskFailed.Error() ||
*record = (*C.uchar)(nullPtr) err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1 return -1
} }
if len(r) == 0 { if len(r) == 0 {
// Empty record // Empty record
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nil)
return 0 return 0
} }
...@@ -130,6 +155,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { ...@@ -130,6 +155,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return C.int(size) return C.int(size)
} }
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free //export mem_free
func mem_free(p unsafe.Pointer) { func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so // "free" may be a better name for this function, but doing so
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
"os" "os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan record ch chan record
bufSize int
} }
type record struct { type record struct {
...@@ -19,33 +36,104 @@ type record struct { ...@@ -19,33 +36,104 @@ type record struct {
err error err error
} }
// NewClient creates a new Client. // WithBuffer sets the client to buffer the training record.
// //
// bufSize is the record buffer size. NextRecord will read from this // bufSize is the record buffer size. NextRecord will read from this
// buffer. // buffer.
func NewClient(addrCh <-chan string, bufSize int) *Client { func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.bufSize = bufSize
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh) for _, opt := range opts {
go c.getRecords() err := opt(c)
return c if err != nil {
return nil, err
}
}
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil
}
// StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
} }
func (c *Client) getRecords() { func (c *Client) getRecords(passID int) {
for { for {
t, err := c.getTask() t, err := c.getTask(passID)
if err != nil { if err != nil {
// TODO(helin): wait before move on with next if err.Error() == ErrPassBefore.Error() ||
// getTask call. err.Error() == ErrNoMoreAvailable.Error() ||
log.Errorln(err) err.Error() == ErrAllTaskFailed.Error() {
continue c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
} }
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if err != nil { if e != nil {
log.Errorln(err) log.Errorln(e)
continue continue
} }
...@@ -68,7 +156,10 @@ func (c *Client) getRecords() { ...@@ -68,7 +156,10 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data // We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly // instance of the task is read. This is not exactly
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
c.taskFinished(t.Meta.ID) err = c.taskFinished(t.Meta.ID)
if err != nil {
log.Errorln(err)
}
} }
} }
...@@ -98,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) { ...@@ -98,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
} }
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
// //
// SetDataset can be call multiple times from different nodes. But // After all tasks are done, another call of SetDataset will start another pass.
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error { func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil) err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
} }
// getTask gets a new task from the master server. // getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) { func (c *Client) getTask(passID int) (Task, error) {
var t Task var t Task
err := c.conn.Call("Service.GetTask", 0, &t) err := c.conn.Call("Service.GetTask", passID, &t)
return t, err return t, err
} }
...@@ -131,3 +225,11 @@ func (c *Client) NextRecord() ([]byte, error) { ...@@ -131,3 +225,11 @@ func (c *Client) NextRecord() ([]byte, error) {
r := <-c.ch r := <-c.ch
return r.r, r.err return r.r, r.err
} }
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
...@@ -40,22 +54,22 @@ func TestGetFinishTask(t *testing.T) { ...@@ -40,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
server := rpc.NewServer() server := rpc.NewServer()
err = server.Register(s) sErr = server.Register(s)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server) mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux) sErr = http.Serve(l, mux)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
}(l) }(l)
...@@ -66,11 +80,21 @@ func TestGetFinishTask(t *testing.T) { ...@@ -66,11 +80,21 @@ func TestGetFinishTask(t *testing.T) {
for i := 0; i < totalTask*chunkPerTask; i++ { for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1) w := recordio.NewWriter(f, -1, -1)
w.Write(nil) _, err = w.Write(nil)
if err != nil {
panic(err)
}
// call Close to force RecordIO writing a chunk. // call Close to force RecordIO writing a chunk.
w.Close() err = w.Close()
if err != nil {
panic(err)
}
}
err = f.Close()
if err != nil {
panic(err)
} }
f.Close()
// Manually intialize client to avoid calling c.getRecords() // Manually intialize client to avoid calling c.getRecords()
c := &Client{} c := &Client{}
...@@ -79,48 +103,56 @@ func TestGetFinishTask(t *testing.T) { ...@@ -79,48 +103,56 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1) ch := make(chan string, 1)
ch <- addr ch <- addr
go c.monitorMaster(ch) go c.monitorMaster(ch)
c.SetDataset([]string{path})
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask() task, cErr := c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("error: %v, pass: %d\n", cErr, i)
} }
tasks = append(tasks, task) tasks = append(tasks, task)
} }
_, err = c.getTask() // getting task before task finishes should return error
if err == nil { _, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].Meta.ID) cErr = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
// call taskFailed once won't put the task to failed queue, just ensure
err = c.taskFailed(tasks[0].Meta) // the call
if err != nil { cErr = c.taskFailed(tasks[0].Meta)
t.Fatalf("Error: %v, pass: %d\n", err, i) if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
tasks = tasks[1:] tasks = tasks[1:]
task, err := c.getTask() _, cErr = c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatal(err) t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
} }
tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.Meta.ID) cErr = c.taskFinished(task.Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatal(cErr)
} }
} }
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i) checkOnePass(i)
} }
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test package master_test
import ( import (
...@@ -6,8 +20,10 @@ import ( ...@@ -6,8 +20,10 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -15,6 +31,18 @@ import ( ...@@ -15,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) { func TestNextRecord(t *testing.T) {
const ( const (
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
...@@ -31,7 +59,7 @@ func TestNextRecord(t *testing.T) { ...@@ -31,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -55,32 +83,67 @@ func TestNextRecord(t *testing.T) { ...@@ -55,32 +83,67 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
w := recordio.NewWriter(f, -1, -1) w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
w.Write([]byte{byte(i)}) _, err = w.Write([]byte{byte(i)})
if err != nil {
panic(err)
}
}
err = w.Close()
if err != nil {
panic(err)
} }
w.Close()
f.Close()
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 { err = f.Close()
t.Fatal(pass, i, "Length should be 1.", r) if err != nil {
panic(err)
}
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
} }
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
if received[r[0]] { received := make(map[byte]bool)
t.Fatal(pass, i, "Received duplicate.", received, r) taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
} }
received[r[0]] = true }()
}
} }
wg.Wait()
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
...@@ -25,15 +39,12 @@ type EtcdClient struct { ...@@ -25,15 +39,12 @@ type EtcdClient struct {
statePath string statePath string
client *clientv3.Client client *clientv3.Client
lock *concurrency.Mutex lock *concurrency.Mutex
sess *concurrency.Session
} }
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debugf("Connecting to etcd at %v", endpoints)
// TODO(helin): gracefully shutdown etcd store. Becuase etcd
// store holds a etcd lock, even though the lock will expire
// when the lease timeout, we need to implement graceful
// shutdown to release the lock.
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
...@@ -53,14 +64,14 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -53,14 +64,14 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath) log.Infof("Trying to acquire lock at %s.", lockPath)
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("Successfully acquired lock at %s.", lockPath) log.Infof("Successfully acquired lock at %s.", lockPath)
put := clientv3.OpPut(addrPath, string(addr)) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -75,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -75,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
statePath: statePath, statePath: statePath,
client: cli, client: cli,
lock: lock, lock: lock,
sess: sess,
} }
return e, nil return e, nil
...@@ -143,9 +155,24 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -143,9 +155,24 @@ func (e *EtcdClient) Load() ([]byte, error) {
return state, nil return state, nil
} }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Errorln(newErr)
}
}
return err
}
// GetKey gets the value by the specify key. // GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key) resp, err := c.Get(ctx, key)
cancel() cancel()
if err != nil { if err != nil {
...@@ -159,8 +186,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { ...@@ -159,8 +186,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return string(v), nil return string(v), nil
} }
// WatchKey watches the specify key and send to valChan if there is some event. // watchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) { func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key) rch := c.Watch(context.Background(), key)
for wresp := range rch { for wresp := range rch {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import "sync" import "sync"
// InMemStore is an in memory implementation of Store interface. // InMemStore is an in memory implementation of Store interface.
// //
// It does not tolerate the fault that casues the program to crash. // It does not tolerate the fault that causes the program to crash.
type InMemStore struct { type InMemStore struct {
mu sync.Mutex mu sync.Mutex
buf []byte buf []byte
...@@ -26,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) { ...@@ -26,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
return m.buf, nil return m.buf, nil
} }
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
...@@ -5,6 +19,7 @@ import ( ...@@ -5,6 +19,7 @@ import (
"compress/gzip" "compress/gzip"
"encoding/gob" "encoding/gob"
"errors" "errors"
"math/rand"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
...@@ -19,10 +34,23 @@ const ( ...@@ -19,10 +34,23 @@ const (
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
) )
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state. // Store is the interface for save and load the master state.
type Store interface { type Store interface {
Save([]byte) error Save([]byte) error
Load() ([]byte, error) Load() ([]byte, error)
Shutdown() error
} }
// Chunk is a chunk of data consisted of several data instances. // Chunk is a chunk of data consisted of several data instances.
...@@ -49,11 +77,12 @@ type taskEntry struct { ...@@ -49,11 +77,12 @@ type taskEntry struct {
NumFailure int NumFailure int
} }
type taskQueues struct { type masterState struct {
Todo []taskEntry Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry Done []taskEntry
Failed []taskEntry Failed []taskEntry
CurPass int
} }
// Service is the master server service. // Service is the master server service.
...@@ -61,16 +90,26 @@ type Service struct { ...@@ -61,16 +90,26 @@ type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
failureMax int failureMax int
ready chan struct{}
store Store store Store
mu sync.Mutex ready chan struct{}
initDone bool initDone bool
taskQueues taskQueues
mu sync.Mutex
// State to be persisted to snapshot.
state masterState
// The trainer that is currently saving model. This state is
// transient, does not need to be persisted to snapshot.
savingTrainer string
} }
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 // generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 { if chunksPerTask <= 0 {
chunksPerTask = 1 chunksPerTask = 1
} }
...@@ -80,7 +119,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -80,7 +119,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id cur.Task.Meta.ID = id
id++ counter++
id = timestamp + randStart + counter
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
} }
...@@ -102,8 +142,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur ...@@ -102,8 +142,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.failureMax = failureMax s.failureMax = failureMax
s.taskQueues = taskQueues{} s.state = masterState{}
s.taskQueues.Pending = make(map[int]taskEntry) s.state.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
s.store = store s.store = store
recovered, err := s.recover() recovered, err := s.recover()
...@@ -141,7 +181,7 @@ func (s *Service) recover() (bool, error) { ...@@ -141,7 +181,7 @@ func (s *Service) recover() (bool, error) {
} }
dec := gob.NewDecoder(gr) dec := gob.NewDecoder(gr)
var tqs taskQueues var tqs masterState
err = dec.Decode(&tqs) err = dec.Decode(&tqs)
if err != nil { if err != nil {
return false, err return false, err
...@@ -154,13 +194,18 @@ func (s *Service) recover() (bool, error) { ...@@ -154,13 +194,18 @@ func (s *Service) recover() (bool, error) {
log.Errorln(err) log.Errorln(err)
} }
s.taskQueues = tqs s.state = tqs
log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.")
for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
}
return true, nil return true, nil
} }
// snapshot *must* be called with s.mu being held. // snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error { func (s *Service) snapshot() error {
// TOOD(helin): etcd request has a size limit, so the snapshot // TODO(helin): etcd request has a size limit, so the snapshot
// size is limited by the max request size. We should either // size is limited by the max request size. We should either
// divide the snapshot into smaller chunks and save under // divide the snapshot into smaller chunks and save under
// different keys, or configure the request size to be big // different keys, or configure the request size to be big
...@@ -169,7 +214,7 @@ func (s *Service) snapshot() error { ...@@ -169,7 +214,7 @@ func (s *Service) snapshot() error {
var buf bytes.Buffer var buf bytes.Buffer
gw := gzip.NewWriter(&buf) gw := gzip.NewWriter(&buf)
enc := gob.NewEncoder(gw) enc := gob.NewEncoder(gw)
err := enc.Encode(s.taskQueues) err := enc.Encode(s.state)
if err != nil { if err != nil {
return err return err
} }
...@@ -215,6 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -215,6 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
count := index.NumChunks() count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := Chunk{ chunk := Chunk{
Path: path, Path: path,
...@@ -231,7 +277,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -231,7 +277,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
// //
// SetDataset can be call multiple times. But only the first call will // SetDataset can be call multiple times. But only the first call will
// be honored. // be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error { func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 { if len(globPaths) == 0 {
return errors.New("no dataset specified") return errors.New("no dataset specified")
} }
...@@ -250,19 +296,20 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -250,19 +296,20 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return err return err
} }
s.taskQueues.Todo = partition(chunks, s.chunksPerTask) s.state.Todo = partition(chunks, s.chunksPerTask)
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
return err return err
} }
close(s.ready) close(s.ready)
s.initDone = true s.initDone = true
return nil return nil
} }
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) { func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch { if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the // new epoch, task launched after the
...@@ -277,17 +324,17 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -277,17 +324,17 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
} }
}() }()
delete(s.taskQueues.Pending, t.Task.Meta.ID) delete(s.state.Pending, t.Task.Meta.ID)
t.NumFailure++ t.NumFailure++
if t.NumFailure > s.failureMax { if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t) s.state.Failed = append(s.state.Failed, t)
return return
} }
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.state.Todo = append(s.state.Todo, t)
return return
} }
...@@ -296,7 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -296,7 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
return return
} }
...@@ -308,51 +355,45 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -308,51 +355,45 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
// must be called with lock held. // must be called with lock held.
func (s *Service) logFields() log.Fields { func (s *Service) logFields() log.Fields {
return log.Fields{ return log.Fields{
"todoLen": len(s.taskQueues.Todo), "todoLen": len(s.state.Todo),
"pendingLen": len(s.taskQueues.Pending), "pendingLen": len(s.state.Pending),
"doneLen": len(s.taskQueues.Done), "doneLen": len(s.state.Done),
"failedLen": len(s.taskQueues.Failed), "failedLen": len(s.state.Failed),
"curPass": s.state.CurPass,
} }
} }
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error { // passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select { select {
case <-s.ready: case <-s.ready:
} }
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if passID < s.state.CurPass {
return ErrPassBefore
}
if passID > s.state.CurPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 { if len(s.state.Todo) == 0 {
if len(s.taskQueues.Done) == 0 { if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 { log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
err := errors.New("all task failed") return ErrAllTaskFailed
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
} }
s.taskQueues.Todo = s.taskQueues.Done log.WithFields(s.logFields()).Warningln("No more available task.")
s.taskQueues.Done = nil return ErrNoMoreAvailable
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
} }
t := s.taskQueues.Todo[0] t := s.state.Todo[0]
t.Task.Meta.Epoch++ t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:] s.state.Todo = s.state.Todo[1:]
s.taskQueues.Pending[t.Task.Meta.ID] = t s.state.Pending[t.Task.Meta.ID] = t
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
return err return err
...@@ -374,7 +415,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -374,7 +415,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return nil return nil
...@@ -382,15 +423,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -382,15 +423,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
// task finished, reset timeout // task finished, reset timeout
t.NumFailure = 0 t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t) s.state.Done = append(s.state.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.state.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { // increase master side pass count if all tasks finished
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") s.state.CurPass++
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.state.Todo = append(s.state.Done, s.state.Failed...)
s.taskQueues.Done = nil s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass)
} }
err := s.snapshot() err := s.snapshot()
...@@ -409,7 +453,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { ...@@ -409,7 +453,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID] t, ok := s.state.Pending[meta.ID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta) log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil return nil
...@@ -418,3 +462,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { ...@@ -418,3 +462,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.processFailedTask(t, meta.Epoch) s.processFailedTask(t, meta.Epoch)
return nil return nil
} }
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import "testing" import "testing"
...@@ -30,7 +44,8 @@ func TestPartionIndex(t *testing.T) { ...@@ -30,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.Meta.ID != i { // test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }
......
package master_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer) go_test(pserver_test DEPS paddle_go_optimizer)
endif() endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer) go_test(pserver_client_test DEPS paddle_go_optimizer)
endif() endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m) target_link_libraries(paddle_go_optimizer stdc++ m)
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
/* /*
...@@ -34,7 +48,6 @@ import ( ...@@ -34,7 +48,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
...@@ -42,10 +55,10 @@ var curHandle C.paddle_pserver_client ...@@ -42,10 +55,10 @@ var curHandle C.paddle_pserver_client
func add(c *client.Client) C.paddle_pserver_client { func add(c *client.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle cli := curHandle
curHandle++ curHandle++
handleMap[client] = c handleMap[cli] = c
return client return cli
} }
func get(client C.paddle_pserver_client) *client.Client { func get(client C.paddle_pserver_client) *client.Client {
...@@ -63,7 +76,7 @@ func remove(client C.paddle_pserver_client) *client.Client { ...@@ -63,7 +76,7 @@ func remove(client C.paddle_pserver_client) *client.Client {
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr { if p == nil {
return nil return nil
} }
...@@ -77,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -77,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -101,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli ...@@ -101,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters) addr := C.GoString(etcdEndpoints)
addr := C.GoString(etcd_endpoints) etcdClient := client.NewEtcd(addr)
etcd_client := client.NewEtcd(addr) c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
return add(c) return add(c)
} }
...@@ -114,30 +130,41 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) { ...@@ -114,30 +130,41 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client) remove(client)
} }
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1 return 1
} }
return C.PSERVER_OK return 0
} }
//export paddle_init_param //export paddle_init_param
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int { func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int {
et := pserver.ElementType(param.element_type) et := pserver.ElementType(param.element_type)
name := C.GoString(param.name) name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{ pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: content}, Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(param_config, int(config_len)), Config: cArrayToSlice(paramConfig, int(configLen)),
} }
c := get(client) c := get(client)
err := c.InitParam(pc) err := c.InitParam(pc)
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name) log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Errorln(err)
...@@ -153,7 +180,7 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { ...@@ -153,7 +180,7 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningln("parameters already initialized, treat paddle_finish_init_params as sucessful.") log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK return C.PSERVER_OK
} }
...@@ -223,12 +250,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -223,12 +250,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
p := ps[i] p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nullPtr { if unsafe.Pointer(param) == nil {
log.Errorln("must pre-allocate parameter.") log.Errorln("must pre-allocate parameter.")
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
if unsafe.Pointer(param.content) != nullPtr { if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) { if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return C.PSERVER_ERROR return C.PSERVER_ERROR
...@@ -243,17 +270,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -243,17 +270,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return C.PSERVER_OK return C.PSERVER_OK
} }
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored func main() {} // Required but ignored
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer) cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
add_style_check_target(test_cclient test_cclient.c) add_style_check_target(test_cclient test_cclient.c)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
...@@ -97,9 +111,5 @@ retry: ...@@ -97,9 +111,5 @@ retry:
getParams(c); getParams(c);
} }
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0; return 0;
} }
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoints = "http://" + etcd_ip + ":2379"
print "etcd endpoints: ", etcd_endpoints
def main(): def main():
...@@ -20,19 +28,20 @@ def main(): ...@@ -20,19 +28,20 @@ def main():
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
is_local=False, is_local=False,
pserver_spec="localhost:3000") pserver_spec=etcd_endpoints,
use_etcd=True)
# event_handler to print training and testing info # event_handler to print training and testing info
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % ( print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost) event.pass_id, event.batch_id, event.cost)
...@@ -47,10 +56,14 @@ def main(): ...@@ -47,10 +56,14 @@ def main():
print "Test %d, %.2f" % (event.pass_id, result.cost) print "Test %d, %.2f" % (event.pass_id, result.cost)
# training # training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
uci_housing.train(), buf_size=500), cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2), batch_size=2),
feeding={'x': 0, feeding={'x': 0,
'y': 1}, 'y': 1},
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client package client
import ( import (
...@@ -13,9 +27,13 @@ import ( ...@@ -13,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers. // Selector selects if the client should initialize parameters and
// reports the initialization process done.
type Selector interface { type Selector interface {
Select() bool // Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
} }
// Server is the identification of a parameter Server. // Server is the identification of a parameter Server.
...@@ -101,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -101,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is // servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from // done, and they need to get the initialized parameters from
// parameter servers using GetParams. // parameter servers using GetParams.
func (c *Client) BeginInitParams() bool { func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select() return c.sel.Select()
} }
...@@ -205,35 +223,9 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) { ...@@ -205,35 +223,9 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return ps, nil return ps, nil
} }
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 { func strHash(s string) uint32 {
h := fnv.New32a() h := fnv.New32a()
h.Write([]byte(s)) _, _ = h.Write([]byte(s))
return h.Sum32() return h.Sum32()
} }
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client_test package client_test
import ( import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -43,7 +59,7 @@ func initClient() [numPserver]int { ...@@ -43,7 +59,7 @@ func initClient() [numPserver]int {
go func(l net.Listener) { go func(l net.Listener) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -77,21 +93,43 @@ func initEtcdClient() { ...@@ -77,21 +93,43 @@ func initEtcdClient() {
log.Errorf("err %v", err) log.Errorf("err %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired) _, err = client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath) if err != nil {
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver)) panic(err)
}
_, err = client.Delete(ctx, pserver.PsPath)
if err != nil {
panic(err)
}
_, err = client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
if err != nil {
panic(err)
}
ports := initClient() ports := initClient()
for i := 0; i < numPserver; i++ { for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i])) _, err = client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
if err != nil {
panic(err)
}
} }
cancel() cancel()
client.Close() err = client.Close()
if err != nil {
panic(err)
}
} }
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -100,27 +138,38 @@ func (l lister) List() []client.Server { ...@@ -100,27 +138,38 @@ func (l lister) List() []client.Server {
return l return l
} }
func ClientTest(t *testing.T, c *client.Client) { func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams() selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
const numParameter = 100 const numParameter = 1000
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb") config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
var p pserver.Parameter wg.Add(1)
p.Name = "p_" + strconv.Itoa(i) go func(i int) {
p.ElementType = pserver.Float32 var p pserver.Parameter
p.Content = make([]byte, (i+1)*100) p.Name = "p_" + strconv.Itoa(i)
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}) p.ElementType = pserver.Float32
if err != nil { p.Content = make([]byte, (i+1)*100)
t.Fatal(err) err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
} if err != nil {
t.Fatal(err)
}
wg.Done()
}(i)
} }
wg.Wait()
err = c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
...@@ -128,7 +177,7 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -128,7 +177,7 @@ func ClientTest(t *testing.T, c *client.Client) {
} }
var grads []pserver.Gradient var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ { for i := 0; i < numParameter; i++ {
var g pserver.Gradient var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i) g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32 g.ElementType = pserver.Float32
...@@ -136,9 +185,31 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -136,9 +185,31 @@ func ClientTest(t *testing.T, c *client.Client) {
grads = append(grads, g) grads = append(grads, g)
} }
err = c.SendGrads(grads) const paramPerGroup = 10
if err != nil { const numGroups = numParameter / paramPerGroup
t.Fatal(err)
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err := c.SendGrads(gs)
if err != nil {
t.Fatal(err)
}
wg.Done()
}(gs)
} }
names := make([]string, numParameter) names := make([]string, numParameter)
...@@ -146,20 +217,35 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -146,20 +217,35 @@ func ClientTest(t *testing.T, c *client.Client) {
names[i] = "p_" + strconv.Itoa(i) names[i] = "p_" + strconv.Itoa(i)
} }
params, err := c.GetParams(names) for i := 0; i < numGroups; i++ {
if err != nil { var ns []string
t.Fatal(err) if i == numGroups-1 {
} ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
if len(names) != len(params) { wg.Add(1)
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) go func(ns []string) {
} params, err := c.GetParams(ns)
if err != nil {
t.Fatal(err)
}
for i := range params { if len(ns) != len(params) {
if names[i] != params[i].Name { t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name) }
}
for i := range params {
if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
}
}
wg.Done()
}(ns)
} }
wg.Wait()
} }
func TestNativeClient(t *testing.T) { func TestNativeClient(t *testing.T) {
...@@ -169,13 +255,14 @@ func TestNativeClient(t *testing.T) { ...@@ -169,13 +255,14 @@ func TestNativeClient(t *testing.T) {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])} servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
} }
c1 := client.NewClient(lister(servers), len(servers), selector(true)) c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1) testClient(t, c1)
} }
// TODO: tmperary disable etcdClient test for dependency of etcd) // EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient(t *testing.T) { func EtcdClient(t *testing.T) {
initEtcdClient() initEtcdClient()
etcdClient := client.NewEtcd(etcdEndpoints) etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true)) c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2) testClient(t, c2)
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client package client
import ( import (
"context" "context"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
DefaultEtcdTimeout time.Duration = 5 * time.Second defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
) )
// EtcdClient is used by pserver client that is a part of trainer process. // Etcd is used by pserver client that is a part of trainer process.
// TODO: // TODO:
// 1. add watcher to watch the change state of pservers) // 1. add watcher to watch the change state of pservers.
// 1. add etcd lock) type Etcd struct {
type EtcdClient struct {
client *clientv3.Client client *clientv3.Client
timeout time.Duration timeout time.Duration
endpoints []string endpoints []string
lock *concurrency.Mutex
} }
// Desired read ps desired number from etcd. // Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int { func (e *Etcd) Desired() int {
var psDesired int var psDesired int
for { for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err) log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -59,26 +80,26 @@ func (p *EtcdClient) Desired() int { ...@@ -59,26 +80,26 @@ func (p *EtcdClient) Desired() int {
} }
// List return the pserver list read from etcd. // List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server { func (e *Etcd) List() []Server {
psDesired := p.Desired() psDesired := e.Desired()
servers := make([]Server, psDesired) servers := make([]Server, psDesired)
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -86,10 +107,10 @@ func (p *EtcdClient) List() []Server { ...@@ -86,10 +107,10 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Infof("got value (%s) for key: %s", psAddr, psKey) log.Debugf("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i servers[i].Index = i
servers[i].Addr = psAddr servers[i].Addr = psAddr
} }
...@@ -99,27 +120,135 @@ func (p *EtcdClient) List() []Server { ...@@ -99,27 +120,135 @@ func (p *EtcdClient) List() []Server {
} }
// NewEtcd create a etcd client to return the state of pserver on etcd. // NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient { func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",") ep := strings.Split(endpoints, ",")
var cli *clientv3.Client var cli *clientv3.Client
var err error var err error
for { for {
cli, err = clientv3.New(clientv3.Config{ cli, err = clientv3.New(clientv3.Config{
Endpoints: ep, Endpoints: ep,
DialTimeout: DefaultEtcdTimeout, DialTimeout: defaultEtcdTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("Init etcd connection failed: %v", err) log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout) time.Sleep(defaultEtcdTimeout)
continue continue
} }
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{ client := &Etcd{
client: cli, client: cli,
timeout: DefaultEtcdTimeout, timeout: defaultEtcdTimeout,
endpoints: ep, endpoints: ep,
} }
return client return client
} }
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath)
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Infof("Successfully acquired lock at %s.", initLockPath)
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Infoln("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Errorln(cErr)
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate trainer crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
...@@ -20,16 +34,19 @@ const ( ...@@ -20,16 +34,19 @@ const (
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information // PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/" PsCheckpoint = "/checkpoints/"
retryTimeout = 5 * time.Second
) )
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination. // tolerance, service registry and coordination.
type EtcdClient struct { type EtcdClient struct {
numPservers int numPservers int
etcdEndpoints string endpoints string
etcdClient *clientv3.Client client *clientv3.Client
// etcdTimeout is also used as retry intervals. sess *concurrency.Session
etcdTimeout time.Duration dialTimeout time.Duration
ttlSec int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string externalIP string
// desired number of pservers in the job. // desired number of pservers in the job.
...@@ -38,19 +55,19 @@ type EtcdClient struct { ...@@ -38,19 +55,19 @@ type EtcdClient struct {
} }
// NewEtcdClient creates an EtcdClient // NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
return &EtcdClient{ return &EtcdClient{
etcdTimeout: timeout, dialTimeout: dialtimeout,
numPservers: numPservers, ttlSec: ttlSec,
etcdEndpoints: endpoints, numPservers: numPservers,
endpoints: endpoints,
} }
} }
// Register registers the pserver on etcd // Register registers the pserver on etcd
// //
// Register returns the index of the current pserver. // Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) { func (e *EtcdClient) Register(port int) (int, error) {
var err error var err error
e.externalIP, err = networkhelper.GetExternalIP() e.externalIP, err = networkhelper.GetExternalIP()
if err != nil { if err != nil {
...@@ -58,19 +75,26 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -58,19 +75,26 @@ func (e *EtcdClient) Register() (int, error) {
} }
// initialize connection to etcd. // initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",") ep := strings.Split(e.endpoints, ",")
for { for {
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: ep, Endpoints: ep,
DialTimeout: e.etcdTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.etcdClient = cli e.client = cli
log.Debugf("inited client to %s", e.etcdEndpoints) sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
log.Errorf("create etcd session error: %v", err)
time.Sleep(retryTimeout)
continue
}
e.sess = sess
log.Debugf("inited client to %s", e.endpoints)
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
...@@ -81,7 +105,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -81,7 +105,7 @@ func (e *EtcdClient) Register() (int, error) {
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
...@@ -92,18 +116,18 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -92,18 +116,18 @@ func (e *EtcdClient) Register() (int, error) {
// wait and set s.desired init value // wait and set s.desired init value
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
} }
...@@ -116,11 +140,11 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -116,11 +140,11 @@ func (e *EtcdClient) Register() (int, error) {
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error var err error
pserverIdx, err = e.registerPserverEtcd(ctx) pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
...@@ -130,19 +154,19 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -130,19 +154,19 @@ func (e *EtcdClient) Register() (int, error) {
} }
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired) dsStr := c.Get(PsDesired)
if dsStr == "" { if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers)) c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
} }
return nil return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
} }
// registerPserverEtcd registers pserver node on etcd using transaction. // registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
...@@ -151,35 +175,20 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -151,35 +175,20 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" { if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info // find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID)) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished") log.Debug("register finished")
idx = i idx = i
registered = true registered = true
break break
} }
} }
if registered == true { if registered {
return nil return nil
} }
return errors.New("not registerd, may due to already have enough pservers") return errors.New("not registered, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil { if err != nil {
...@@ -192,11 +201,12 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -192,11 +201,12 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
// GetKey gets the value by the specified key // GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key) resp, err := e.client.Get(ctx, key)
cancel() cancel()
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
return []byte{}, nil return []byte{}, nil
...@@ -206,12 +216,34 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ...@@ -206,12 +216,34 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
} }
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value)) var err error
if withLease {
_, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
} else {
_, err = e.client.Put(ctx, key, string(value))
}
cancel() cancel()
if err != nil { return err
return err }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
var err error
if e.sess != nil {
err = e.sess.Close()
}
if e.client != nil {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
log.Errorln(newErr)
} else {
err = newErr
}
}
} }
return nil return err
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
...@@ -14,15 +28,15 @@ import ( ...@@ -14,15 +28,15 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct { type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType elementType ElementType
contentLen int
config []byte
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr { if p == nil {
return nil return nil
} }
...@@ -37,10 +51,11 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -37,10 +51,11 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer { func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{} o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType o.elementType = paramWithConfigs.Param.ElementType
o.contentLen = len(paramWithConfigs.Param.Content)
p := paramWithConfigs.Param p := paramWithConfigs.Param
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State s := State
paramBufferSize := C.size_t(len(p.Content) / C.sizeof_float) paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": paramBufferSize, "ParamSize": paramBufferSize,
...@@ -56,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer ...@@ -56,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0]) cstate = unsafe.Pointer(&s[0])
} }
o.config = c
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s))) C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
return o return o
...@@ -78,7 +94,11 @@ func (o *optimizer) UpdateParameter(g Gradient) error { ...@@ -78,7 +94,11 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float) if o.contentLen != len(g.Content) {
return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content))
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 { if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r) return fmt.Errorf("optimizer update returned error code: %d", r)
} }
...@@ -86,8 +106,8 @@ func (o *optimizer) UpdateParameter(g Gradient) error { ...@@ -86,8 +106,8 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
} }
func (o *optimizer) Cleanup() { func (o *optimizer) Cleanup() {
if unsafe.Pointer(o.opt) != nullPtr { if unsafe.Pointer(o.opt) != nil {
C.paddle_release_optimizer(o.opt) C.paddle_release_optimizer(o.opt)
o.opt = (*C.struct_paddle_optimizer)(nullPtr) o.opt = (*C.struct_paddle_optimizer)(nil)
} }
} }
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
......
此差异已折叠。
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver_test package pserver_test
import ( import (
...@@ -16,7 +30,7 @@ const ( ...@@ -16,7 +30,7 @@ const (
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -31,7 +45,7 @@ func TestServiceFull(t *testing.T) { ...@@ -31,7 +45,7 @@ func TestServiceFull(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var p1 pserver.Parameter var p1 pserver.Parameter
...@@ -40,40 +54,40 @@ func TestServiceFull(t *testing.T) { ...@@ -40,40 +54,40 @@ func TestServiceFull(t *testing.T) {
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var param pserver.Parameter var param pserver.Parameter
err = s.GetParam("param_b", &param) err = s.GetParam("param_b", &param)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
if !reflect.DeepEqual(param, p1) { if !reflect.DeepEqual(param, p1) {
t.FailNow() t.Fatal("not equal:", param, p1)
} }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil) err = s.SendGrad(g1, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.SendGrad(g2, nil) err = s.SendGrad(g2, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var param1 pserver.Parameter var param1 pserver.Parameter
err = s.GetParam("param_a", &param1) err = s.GetParam("param_a", &param1)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
// don't compare content, since it's already changed by // don't compare content, since it's already changed by
...@@ -82,39 +96,39 @@ func TestServiceFull(t *testing.T) { ...@@ -82,39 +96,39 @@ func TestServiceFull(t *testing.T) {
p.Content = nil p.Content = nil
if !reflect.DeepEqual(param1, p) { if !reflect.DeepEqual(param1, p) {
t.FailNow() t.Fatal("not equal:", param1, p)
} }
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err.Error() != pserver.AlreadyInitialized { if err.Error() != pserver.AlreadyInitialized {
t.FailNow() t.Fatal(err)
} }
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.Fatal(err)
} }
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -154,12 +168,12 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -154,12 +168,12 @@ func TestBlockUntilInitialized(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
wg.Wait() wg.Wait()
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(network_helper_test) go_test(network_helper_test)
endif() endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper package networkhelper
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper package networkhelper
import "testing" import "testing"
......
...@@ -21,22 +21,15 @@ ...@@ -21,22 +21,15 @@
# #
# It same as PYTHONPATH=${YOUR_PYTHON_PATH}:$PYTHONPATH {exec...} # It same as PYTHONPATH=${YOUR_PYTHON_PATH}:$PYTHONPATH {exec...}
# #
PYPATH=""
if ! python -c "import paddle" >/dev/null 2>/dev/null; then set -x
PYPATH="" while getopts "d:" opt; do
set -x case $opt in
while getopts "d:" opt; do d)
case $opt in PYPATH=$OPTARG
d) ;;
PYPATH=$OPTARG esac
;; done
esac shift $(($OPTIND - 1))
done export PYTHONPATH=$PYPATH:$PYTHONPATH
shift $(($OPTIND - 1)) $@
export PYTHONPATH=$PYPATH:$PYTHONPATH
$@
else
echo "paddle package is already in your PYTHONPATH. But unittest need a clean environment."
echo "Please uninstall paddle package before start unittest. Try to 'pip uninstall paddle'"
exit 1
fi
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const { ...@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double Evaluator::getValue(const std::string name) const { double Evaluator::getValue(const std::string name) const {
paddle::Error err; paddle::Error err;
double v = m->rawPtr->getValue(name, &err); double v = m->rawPtr->getValue(name, &err);
if (err) { if (!err.isOK()) {
throw std::runtime_error(err.msg()); throw std::runtime_error(err.msg());
} }
return v; return v;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册