提交 fc2e6f1c 编写于 作者: C caoying03

Merge branch 'develop' into print_attention_weight

#!/bin/bash
set -e
readonly VERSION="3.8"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@
......@@ -24,4 +24,5 @@ cmake-build-*
python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
......@@ -17,14 +17,20 @@
- id: detect-private-key
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
- repo: local
hooks:
- id: clang-formater
- repo: https://github.com/dnephin/pre-commit-golang
sha: e4693a4c282b4fc878eda172a929f7a6508e7d16
- id: clang-format-with-version-check
name: clang-format
description: Format files with ClangFormat.
entry: ./.clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
- id: go-fmt
files: (.*\.go)
- id: go-lint
files: (.*\.go)
- id: go-fmt
types:
- go
- id: gometalinter
types:
- go
......@@ -4,6 +4,7 @@ cache:
- $HOME/.ccache
- $HOME/.cache/pip
- $TRAVIS_BUILD_DIR/build/third_party
- $TRAVIS_BUILD_DIR/build_android/third_party
sudo: required
dist: trusty
os:
......@@ -11,6 +12,7 @@ os:
env:
- JOB=build_doc
- JOB=check_style
- JOB=build_android
addons:
apt:
packages:
......@@ -35,10 +37,12 @@ before_install:
- if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
- pip install rarfile
- pip install -r $TRAVIS_BUILD_DIR/python/requirements.txt
- pip install wheel sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit LinkChecker
- curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- go get -u github.com/alecthomas/gometalinter
- gometalinter --install
- |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script:
......
......@@ -13,10 +13,9 @@
# limitations under the License
cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
include(system)
......@@ -37,6 +36,8 @@ include(simd)
################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
......@@ -54,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
......@@ -75,6 +77,10 @@ if(ANDROID)
"Disable PYTHON when cross-compiling for Android" FORCE)
set(WITH_RDMA OFF CACHE STRING
"Disable RDMA when cross-compiling for Android" FORCE)
set(WITH_MKLDNN OFF CACHE STRING
"Disable MKLDNN when cross-compiling for Android" FORCE)
set(WITH_MKLML OFF CACHE STRING
"Disable MKLML package when cross-compiling for Android" FORCE)
endif(ANDROID)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
......@@ -88,6 +94,7 @@ endif()
########################################################################################
include(external/mklml) # download mklml package
include(external/zlib) # download, build, install zlib
include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog
......@@ -95,6 +102,7 @@ include(external/gtest) # download, build, install gtest
include(external/protobuf) # download, build, install protobuf
include(external/python) # download, build, install python
include(external/openblas) # download, build, install openblas
include(external/mkldnn) # download, build, install mkldnn
include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any
......@@ -114,8 +122,8 @@ include(version) # set PADDLE_VERSION
include(coveralls) # set code coverage
include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${PADDLE_SOURCE_DIR}")
include_directories("${PADDLE_SOURCE_DIR}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS})
......@@ -130,14 +138,19 @@ set(EXTERNAL_LIBS
)
if(WITH_GPU)
list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
endif(NOT WITH_DSO)
endif(WITH_GPU)
if(WITH_MKLDNN)
list(APPEND EXTERNAL_LIBS ${MKLDNN_LIB} ${MKLDNN_IOMP_LIB})
endif()
if(USE_NNPACK)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB} "rt")
include(external/nnpack)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIBS})
endif(USE_NNPACK)
add_subdirectory(proto)
......@@ -152,10 +165,12 @@ if(WITH_GOLANG)
add_subdirectory(go)
endif(WITH_GOLANG)
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
add_subdirectory(paddle)
if(WITH_PYTHON)
add_subdirectory(python)
endif()
if(WITH_DOC)
add_subdirectory(doc)
endif()
......@@ -25,27 +25,26 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \
apt-get install -y \
git python-pip python-dev openssh-server bison \
wget unzip tar xz-utils bzip2 gzip coreutils ntp \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-numpy python-matplotlib gcc g++ \
automake locales clang-format-3.8 swig doxygen cmake \
python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \
apt-get clean -y
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \
tar -xz -C /usr/local && \
mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz
mkdir /root/gopath/src
ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# install glide
RUN curl -q https://glide.sh/get | sh
RUN curl -s -q https://glide.sh/get | sh
# git credential to skip password typing
RUN git config --global credential.helper store
......@@ -56,19 +55,23 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter
# version util jupyter fixes this issue.
RUN pip install --upgrade pip && \
pip install -U 'protobuf==3.1.0' && \
pip install -U wheel pillow BeautifulSoup && \
pip install -U wheel && \
pip install -U docopt PyYAML sphinx && \
pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \
pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \
pip install -U sphinx-rtd-theme==0.1.9 recommonmark
RUN pip install pre-commit 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install rarfile
pip install opencv-python
COPY ./python/requirements.txt /root/
RUN pip install -r /root/requirements.txt
# To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
# the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2
RUN apt-get install -y libssl-dev libffi-dev
RUN pip install certifi urllib3[secure]
# Install woboq_codebrowser to /woboq
RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \
(cd /woboq \
......
......@@ -14,6 +14,17 @@ RUN apt-get update && \
wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
apt-get clean -y
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# git credential to skip password typing
RUN git config --global credential.helper store
......
......@@ -72,7 +72,7 @@ We provide [English](http://doc.paddlepaddle.org/develop/doc/) and
- [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from the this online interactive book that can run in Jupyter Notebook.
You might want to start from this online interactive book that can run in Jupyter Notebook.
- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html)
......
......@@ -15,23 +15,44 @@
set(CBLAS_FOUND OFF)
## Find MKL First.
set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs")
set(MKL_ROOT ${INTEL_ROOT}/mkl CACHE PATH "Folder contains MKL")
## Find MKLML First.
if(WITH_MKLML AND MKLML_INC_DIR AND MKLML_LIB)
set(CBLAS_FOUND ON)
set(CBLAS_PROVIDER MKLML)
set(CBLAS_INC_DIR ${MKLML_INC_DIR})
set(CBLAS_LIBRARIES ${MKLML_LIB})
add_definitions(-DPADDLE_USE_MKLML)
add_definitions(-DLAPACK_FOUND)
message(STATUS "Found cblas and lapack in MKLML "
"(include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
return()
endif()
## Then find MKL.
set(INTEL_MKL_ROOT "/opt/intel/mkl" CACHE PATH "Folder contains intel mkl libs")
set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder contains env MKL")
set(MKL_INCLUDE_SEARCH_PATHS
${MKL_ROOT}/include
${INTEL_MKL_ROOT}/include)
set(MKL_LIB_SEARCH_PATHS
${MKL_ROOT}/lib
${MKL_ROOT}/lib/intel64
${INTEL_MKL_ROOT}/lib
${INTEL_MKL_ROOT}/lib/intel64)
find_path(MKL_INC_DIR mkl.h PATHS
${MKL_ROOT}/include)
${MKL_INCLUDE_SEARCH_PATHS})
find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS
${MKL_ROOT}/include)
${MKL_INCLUDE_SEARCH_PATHS})
find_library(MKL_CORE_LIB NAMES mkl_core PATHS
${MKL_ROOT}/lib
${MKL_ROOT}/lib/intel64)
${MKL_LIB_SEARCH_PATHS})
find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS
${MKL_ROOT}/lib
${MKL_ROOT}/lib/intel64)
${MKL_LIB_SEARCH_PATHS})
find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS
${MKL_ROOT}/lib
${MKL_ROOT}/lib/intel64)
${MKL_LIB_SEARCH_PATHS})
if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64)
set(CBLAS_FOUND ON)
......
......@@ -28,6 +28,10 @@ if(NOT WITH_TIMER)
add_definitions(-DPADDLE_DISABLE_TIMER)
endif(NOT WITH_TIMER)
if(USE_EIGEN_FOR_BLAS)
add_definitions(-DPADDLE_USE_EIGEN_FOR_BLAS)
endif(USE_EIGEN_FOR_BLAS)
if(NOT WITH_PROFILER)
add_definitions(-DPADDLE_DISABLE_PROFILER)
endif(NOT WITH_PROFILER)
......@@ -67,6 +71,28 @@ else()
include_directories(${CUDA_TOOLKIT_INCLUDE})
endif(NOT WITH_GPU)
if(WITH_MKLDNN)
add_definitions(-DPADDLE_USE_MKLDNN)
if (WITH_MKLML AND MKLDNN_IOMP_DIR)
message(STATUS "Enable Intel OpenMP at ${MKLDNN_IOMP_DIR}")
set(OPENMP_FLAGS "-fopenmp")
set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
else()
find_package(OpenMP)
if(OPENMP_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
else()
message(WARNING "Can not find OpenMP."
"Some performance features in MKLDNN may not be available")
endif()
endif()
endif(WITH_MKLDNN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
......@@ -102,12 +128,19 @@ if(WITH_GOLANG)
message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide")
endif()
add_custom_target(go_vendor)
add_custom_command(TARGET go_vendor
# this command will only run when the file it depends is missing
# or has changed, or the output is missing.
add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/glide
COMMAND env GOPATH=${GOPATH} ${GLIDE} install
COMMAND touch ${CMAKE_BINARY_DIR}/glide
DEPENDS ${PADDLE_SOURCE_DIR}/go/glide.lock
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go"
)
add_dependencies(go_vendor go_path)
)
# depends on the custom command which outputs
# ${CMAKE_BINARY_DIR}/glide, the custom command does not need to
# run every time this target is built.
add_custom_target(go_vendor DEPENDS ${CMAKE_BINARY_DIR}/glide go_path)
endif()
endif(WITH_GOLANG)
......@@ -27,7 +27,8 @@ set(IGNORE_PATTERN
.*cblas\\.h.*
.*\\.pb\\.txt
.*LtrDataProvider.*
.*MultiDataProvider.*)
.*MultiDataProvider.*
.*pb.*)
# add_style_check_target
#
......@@ -41,27 +42,21 @@ macro(add_style_check_target TARGET_NAME)
if(WITH_STYLE_CHECK)
set(SOURCES_LIST ${ARGN})
list(REMOVE_DUPLICATES SOURCES_LIST)
list(SORT SOURCES_LIST)
foreach(filename ${SOURCES_LIST})
set(LINT ON)
foreach(pattern ${IGNORE_PATTERN})
if(filename MATCHES ${pattern})
message(STATUS "DROP LINT ${filename}")
set(LINT OFF)
list(REMOVE_ITEM SOURCES_LIST ${filename})
endif()
endforeach()
if(LINT MATCHES ON)
get_filename_component(base_filename ${filename} NAME)
set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint)
add_custom_command(OUTPUT ${CUR_GEN}
PRE_BUILD
COMMAND env ${py_env} "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
"--filter=${STYLE_FILTER}"
"--write-success=${CUR_GEN}" ${filename}
DEPENDS ${filename}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endforeach()
if(SOURCES_LIST)
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/scripts/cpplint.py"
"--filter=${STYLE_FILTER}"
${SOURCES_LIST}
COMMENT "cpplint: Checking source code style"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endif()
endmacro()
......@@ -108,6 +108,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
SET(CMAKE_SYSTEM_PROCESSOR aarch64)
ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF()
......@@ -166,7 +167,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
......@@ -193,6 +194,10 @@ ELSE()
SET(CMAKE_ANDROID_STANDALONE_TOOLCHAIN ${ANDROID_STANDALONE_TOOLCHAIN})
ENDIF()
SET(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ABI})
SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE})
SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON})
IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE})
IF(ANDROID_ABI STREQUAL "armeabi-v7a")
SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON})
ENDIF()
ENDIF()
ENDIF()
......@@ -2,7 +2,7 @@ if(NOT WITH_GPU)
return()
endif()
set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT")
set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT")
find_path(CUDNN_INCLUDE_DIR cudnn.h
PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
......
......@@ -7,8 +7,8 @@ INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add(
extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/thelink2012/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020"
GIT_REPOSITORY "https://github.com/PaddlePaddle/any.git"
GIT_TAG "15595d8324be9e8a9a80d9ae442fdd12bd66df5d"
PREFIX ${ANY_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048"
GIT_TAG "master"
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
ExternalProject_Add(
extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/gflags/gflags.git"
# TODO(yiwang): The annoying warnings mentioned in
# https://github.com/PaddlePaddle/Paddle/issues/3277 are caused by
# gflags. I fired a PR https://github.com/gflags/gflags/pull/230
# to fix it. Before it gets accepted by the gflags team, we use
# my personal fork, which contains above fix, temporarily. Let's
# change this back to the official Github repo once my PR is
# merged.
GIT_REPOSITORY "https://github.com/wangkuiyi/gflags.git"
PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
......@@ -52,6 +52,7 @@ ExternalProject_Add(
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog)
ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog)
......@@ -34,9 +34,15 @@ IF(WITH_TESTING)
"${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE)
ENDIF(WIN32)
IF(WITH_MKLML)
# wait for mklml downloading completed
SET(GTEST_DEPENDS ${MKLML_PROJECT})
ENDIF()
ExternalProject_Add(
extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${GTEST_DEPENDS}
GIT_REPOSITORY "https://github.com/google/googletest.git"
GIT_TAG "release-1.8.0"
PREFIX ${GTEST_SOURCES_DIR}
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_MKLDNN})
return()
ENDIF(NOT ${WITH_MKLDNN})
INCLUDE(ExternalProject)
SET(MKLDNN_PROJECT "extern_mkldnn")
SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with MKLDNN in Paddle yet."
"Force WITH_MKLDNN=OFF")
SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in Windows and MacOS" FORCE)
return()
ENDIF()
SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR})
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
SET(MKLDNN_MKLROOT ${MKLML_ROOT})
SET(MKLDNN_IOMP_LIB ${MKLML_IOMP_LIB})
SET(MKLDNN_IOMP_DIR ${MKLML_LIB_DIR})
MESSAGE(STATUS "Build MKLDNN with ${MKLDNN_MKLROOT}")
ENDIF()
ExternalProject_Add(
${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "v0.9"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DMKLROOT=${MKLDNN_MKLROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
-DMKLROOT:PATH=${MKLDNN_MKLROOT}
)
ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIB}")
LIST(APPEND external_project_dependencies mkldnn)
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_MKLML})
return()
ENDIF(NOT ${WITH_MKLML})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with MKLML in Paddle yet."
"Force WITH_MKLML=OFF")
SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(MKLML_PROJECT "extern_mklml")
SET(MKLML_VER "mklml_lnx_2018.0.20170720")
SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz")
SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
SET(MKLML_DST_DIR "mklml")
SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER})
SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
FILE(WRITE ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(MKLML)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${MKLML_VER}\n"
" DESTINATION ${MKLML_DST_DIR})\n")
ExternalProject_Add(
${MKLML_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${MKLML_SOURCE_DIR}
DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate -qO- ${MKLML_URL} | tar xz -C ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLML_INSTALL_ROOT}
)
ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB})
ADD_DEPENDENCIES(mklml ${MKLML_PROJECT})
LIST(APPEND external_project_dependencies mklml)
......@@ -7,10 +7,24 @@ set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK")
find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include)
find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib)
find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib)
find_library(NNPACK_UKERNELS_LIB NAMES nnpack_ukernels PATHS ${NNPACK_ROOT}/lib)
find_library(NNPACK_CPUFEATURES_LIB NAMES cpufeatures PATHS ${NNPACK_ROOT}/lib)
if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB)
set(NNPACK_FOUND ON)
INCLUDE_DIRECTORIES(${NNPACK_INC_DIR})
set(NNPACK_LIBS)
list(APPEND NNPACK_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB})
if (NNPACK_UKERNELS_LIB)
list(APPEND NNPACK_LIBS ${NNPACK_UKERNELS_LIB})
endif()
if (NNPACK_CPUFEATURES_LIB)
list(APPEND NNPACK_LIBS ${NNPACK_CPUFEATURES_LIB})
endif()
if(NOT ANDROID)
list(APPEND NNPACK_LIBS "rt")
endif()
else()
message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})")
endif()
......@@ -69,9 +69,22 @@ ENDIF(NOT ${CBLAS_FOUND})
MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}")
INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
ADD_LIBRARY(cblas STATIC IMPORTED)
SET_PROPERTY(TARGET cblas PROPERTY IMPORTED_LOCATION ${CBLAS_LIBRARIES})
# FIXME(gangliao): generate cblas target to track all high performance
# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
IF(${CBLAS_PROVIDER} MATCHES MKL)
ADD_LIBRARY(cblas SHARED ${dummyfile})
ELSE()
ADD_LIBRARY(cblas STATIC ${dummyfile})
ENDIF()
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas)
LIST(APPEND external_project_dependencies cblas)
ELSE()
IF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
ADD_DEPENDENCIES(cblas mklml)
ENDIF()
ENDIF(NOT ${CBLAS_FOUND})
......@@ -24,7 +24,6 @@ IF(WITH_PYTHON)
ENDIF(WITH_PYTHON)
SET(py_env "")
SET(USE_VIRTUALENV_FOR_TEST 1)
IF(PYTHONINTERP_FOUND)
find_python_module(pip REQUIRED)
find_python_module(numpy REQUIRED)
......
......@@ -110,7 +110,7 @@ set(COMMON_FLAGS
-Wno-error=literal-suffix
-Wno-error=sign-compare
-Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in Pybind11
-Wno-error=parentheses-equality # Warnings in pybind11
)
set(GPU_COMMON_FLAGS
......@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
)
if (APPLE)
......@@ -189,6 +190,7 @@ endif()
# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
list(APPEND CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
endif()
# Custom gpu architecture
......
......@@ -104,6 +104,7 @@ function(merge_static_libs TARGET_NAME)
foreach(lib ${libs})
list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
endforeach()
list(REMOVE_DUPLICATES libs_deps)
if(APPLE) # Use OSX's libtool to merge archives
# To produce a library we need at least one source file.
......@@ -127,7 +128,7 @@ function(merge_static_libs TARGET_NAME)
# Get the file names of the libraries to be merged
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
endforeach()
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
else() # general UNIX: use "ar" to extract objects and re-add to a common lib
......@@ -145,11 +146,11 @@ function(merge_static_libs TARGET_NAME)
DEPENDS ${lib} ${objdir}
WORKING_DIRECTORY ${objdir})
# Empty dummy source file that goes into merged library
set(mergebase ${lib}.mergebase.c)
add_custom_command(OUTPUT ${mergebase}
COMMAND ${CMAKE_COMMAND} -E touch ${mergebase}
DEPENDS ${objlistfile})
# Empty dummy source file that goes into merged library
set(mergebase ${lib}.mergebase.c)
add_custom_command(OUTPUT ${mergebase}
COMMAND ${CMAKE_COMMAND} -E touch ${mergebase}
DEPENDS ${objlistfile})
list(APPEND mergebases "${mergebase}")
endforeach()
......@@ -184,6 +185,16 @@ function(cc_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
endif()
# cpplint code style
foreach(source_file ${cc_library_SRCS})
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
endif()
endforeach()
add_style_check_target(${TARGET_NAME} ${cc_library_SRCS} ${cc_library_HEADERS})
else(cc_library_SRCS)
if (cc_library_DEPS)
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
......@@ -234,6 +245,14 @@ function(nv_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
target_link_libraries(${TARGET_NAME} ${nv_library_DEPS})
endif()
# cpplint code style
foreach(source_file ${nv_library_SRCS})
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
endif()
endforeach()
add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS})
else(nv_library_SRCS)
if (nv_library_DEPS)
merge_static_libs(${TARGET_NAME} ${nv_library_DEPS})
......@@ -285,8 +304,22 @@ function(go_library TARGET_NAME)
set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
endif()
# Add dummy code to support `make target_name` under Terminal Command
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
# This custom command will always run since it depends on a not
# existing file.
add_custom_command(
OUTPUT dummy_rebulid_${TARGET_NAME}
COMMAND cmake -E touch ${dummyfile}
)
# Create a custom target that depends on the custom command output
# file, so the custom command can be referenced as a dependency by
# `add_dependencies`.
add_custom_target(rebuild_${TARGET_NAME}
DEPENDS dummy_rebulid_${TARGET_NAME}
)
# Add dummy code to support `make target_name` under Terminal Command
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
if (go_library_SHARED OR go_library_shared)
add_library(${TARGET_NAME} SHARED ${dummyfile})
......@@ -297,6 +330,12 @@ function(go_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS)
# The "source file" of the library is `${dummyfile}` which never
# change, so the target will never rebuild. Make the target depends
# on the custom command that touches the library "source file", so
# rebuild will always happen.
add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME})
set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
......@@ -337,7 +376,7 @@ function(go_test TARGET_NAME)
string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test -race
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
......@@ -364,3 +403,16 @@ function(py_proto_compile TARGET_NAME)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction()
function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
python2 ${py_test_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
......@@ -12,7 +12,7 @@ set(CPACK_PACKAGE_DESCRIPTION "")
set(CPACK_DEBIAN_PACKAGE_DEPENDS "libpython2.7-dev, libstdc++6, python-pip, curl, libgfortran3, python-pip-whl")
set(CPACK_DEBIAN_PACKAGE_SECTION Devel)
set(CPACK_DEBIAN_PACKAGE_VERSION ${PADDLE_VERSION})
set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${PROJ_ROOT}/paddle/scripts/deb/postinst")
set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${PADDLE_SOURCE_DIR}/paddle/scripts/deb/postinst")
#set(CPACK_GENERATOR "DEB")
# Start cpack
include (CMakePackageConfigHelpers)
......
......@@ -118,7 +118,6 @@ endfunction()
macro(add_unittest_without_exec TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN})
link_paddle_test(${TARGET_NAME})
add_style_check_target(${TARGET_NAME} ${ARGN})
endmacro()
# add_unittest
......@@ -142,17 +141,20 @@ endmacro()
function(create_resources res_file output_file)
add_custom_command(
OUTPUT ${output_file}
COMMAND python ARGS ${PROJ_ROOT}/cmake/make_resource.py ${res_file} ${output_file}
DEPENDS ${res_file} ${PROJ_ROOT}/cmake/make_resource.py)
COMMAND python ARGS ${PADDLE_SOURCE_DIR}/cmake/make_resource.py ${res_file} ${output_file}
DEPENDS ${res_file} ${PADDLE_SOURCE_DIR}/cmake/make_resource.py)
endfunction()
# Create a python unittest using run_python_tests.sh,
# which takes care of making correct running environment
function(add_python_test TEST_NAME)
add_test(NAME ${TEST_NAME}
COMMAND env PADDLE_PACKAGE_DIR=${PADDLE_PYTHON_PACKAGE_DIR}
bash ${PROJ_ROOT}/paddle/scripts/run_python_tests.sh
${USE_VIRTUALENV_FOR_TEST} ${PYTHON_EXECUTABLE} ${ARGN}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
foreach(arg ${ARGN})
get_filename_component(py_fn ${arg} NAME_WE)
set(TRG_NAME ${TEST_NAME}_${py_fn})
add_test(NAME ${TRG_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_PACKAGE_DIR}
python2 ${arg}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endforeach()
endfunction()
......@@ -4,7 +4,7 @@ set(tmp_version "HEAD")
while ("${PADDLE_VERSION}" STREQUAL "")
execute_process(
COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 ${tmp_version}
WORKING_DIRECTORY ${PROJ_ROOT}
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_TAG_NAME
RESULT_VARIABLE GIT_RESULT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
......
......@@ -104,6 +104,11 @@ cross_channel_norm
------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm
:noindex:
row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers
================
......@@ -198,6 +203,10 @@ identity_projection
.. autoclass:: paddle.v2.layer.identity_projection
:noindex:
slice_projection
-------------------
.. autoclass:: paddle.v2.layer.slice_projection
:noindex:
table_projection
----------------
......@@ -248,6 +257,16 @@ seq_concat
.. autoclass:: paddle.v2.layer.seq_concat
:noindex:
kmax_sequence_score
-------------------
.. autoclass:: paddle.v2.layer.kmax_sequence_score
:noindex:
sub_nested_seq
--------------
.. autoclass:: paddle.v2.layer.sub_nested_seq
:noindex:
Reshaping Layers
================
......@@ -316,6 +335,11 @@ scaling
.. autoclass:: paddle.v2.layer.scaling
:noindex:
clip
----
.. autoclass:: paddle.v2.layer.clip
:noindex:
slope_intercept
---------------
.. autoclass:: paddle.v2.layer.slope_intercept
......@@ -338,6 +362,11 @@ trans
.. autoclass:: paddle.v2.layer.trans
:noindex:
scale_shift
-----------
.. autoclass:: paddle.v2.layer.scale_shift
:noindex:
Sampling Layers
===============
......@@ -474,6 +503,11 @@ prelu
.. autoclass:: paddle.v2.layer.prelu
:noindex:
gated_unit
-----------
.. autoclass:: paddle.v2.layer.gated_unit
:noindex:
Detection output Layer
======================
......
## Auto Gradient Checker Design
## Backgraound:
- Operator forward computing is easy to check if the result is right because it has a clear definition. **But** backpropagation is a notoriously difficult algorithm to debug and get right:
- 1. you should get the right backpropagation formula according to the forward computation.
- 2. you should implement it right in CPP.
- 3. it's difficult to prepare test data.
- Auto gradient check gets a numeric gradient by forward Operator and use it as a reference of the backward Operator's result. It has several advantages:
- 1. numeric gradient checker only need forward operator.
- 2. user only need to prepare the input data for forward Operator.
## Mathematical Theory
The following two document from stanford has a detailed explanation of how to get numeric gradient and why it's useful.
- [Gradient checking and advanced optimization(en)](http://deeplearning.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization)
- [Gradient checking and advanced optimization(cn)](http://ufldl.stanford.edu/wiki/index.php/%E6%A2%AF%E5%BA%A6%E6%A3%80%E9%AA%8C%E4%B8%8E%E9%AB%98%E7%BA%A7%E4%BC%98%E5%8C%96)
## Numeric Gradient Implementation
### Python Interface
```python
def get_numeric_gradient(op,
input_values,
output_name,
input_to_check,
delta=0.005,
local_scope=None):
"""
Get Numeric Gradient for an operator's input.
:param op: C++ operator instance, could be an network
:param input_values: The input variables. Should be an dictionary, key is
variable name. Value is numpy array.
:param output_name: The final output variable name.
:param input_to_check: The input variable need to get gradient.
:param delta: The perturbation value for numeric gradient method. The
smaller delta is, the more accurate result will get. But if that delta is
too small, it could occur numerical stability problem.
:param local_scope: The local scope used for get_numeric_gradient.
:return: The gradient array in numpy format.
"""
```
### Explaination:
- Why need `output_name`
- One Operator may have multiple Output, you can get independent gradient from each Output. So user should set one output to calculate.
- Why need `input_to_check`
- One operator may have multiple inputs. Gradient Op can calculate the gradient of these Inputs at the same time. But Numeric Gradient needs to calculate them one by one. So `get_numeric_gradient` is designed to calculate the gradient for one input. If you need to compute multiple inputs, you can call `get_numeric_gradient` multiple times.
### Core Algorithm Implementation
```python
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
# get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i)
# add delta to it, run op and then get the sum of the result tensor.
x_pos = origin + delta
tensor_to_check.set_float_element(i, x_pos)
y_pos = get_output()
# plus delta to this element, run op and get the sum of the result tensor.
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
# restore old value
tensor_to_check.set_float_element(i, origin)
# compute the gradient of this element and store it into a numpy array.
gradient_flat[i] = (y_pos - y_neg) / delta / 2
# reshape the gradient result to the shape of the source tensor.
return gradient_flat.reshape(tensor_to_check.get_dims())
```
## Auto Graident Checker Framework
Each Operator Kernel has three kinds of Gradient:
- 1. Numeric Gradient
- 2. CPU Operator Gradient
- 3. GPU Operator Gradient(if supported)
Numeric Gradient Only relies on forward Operator. So we use Numeric Gradient as the reference value.
- 1. calculate the numeric gradient.
- 2. calculate CPU kernel Gradient with the backward Operator and compare it with the numeric gradient.
- 3. calculate GPU kernel Gradient with the backward Operator and compare it with the numeric gradient.(if support GPU)
#### Python Interface
```python
def check_grad(self,
forward_op,
input_vars,
inputs_to_check,
output_name,
no_grad_set=None,
only_cpu=False,
max_relative_error=0.005):
"""
:param forward_op: used to create backward_op
:param input_vars: numpy value of input variable. The following
computation will use these variables.
:param inputs_to_check: inputs var names that should check gradient.
:param output_name: output name that used to
:param max_relative_error: The relative tolerance parameter.
:param no_grad_set: used when create backward ops
:param only_cpu: only compute and check gradient on cpu kernel.
:return:
"""
```
### How to check if two numpy array is close enough?
if `abs_numeric_grad` is nearly zero, then use abs error for numeric_grad, not relative
```python
numeric_grad = ...
operator_grad = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())
abs_numeric_grad = numpy.abs(numeric_grad)
# if abs_numeric_grad is nearly zero, then use abs error for numeric_grad, not relative
# error.
abs_numeric_grad[abs_numeric_grad < 1e-3] = 1
diff_mat = numpy.abs(abs_numeric_grad - operator_grad) / abs_numeric_grad
max_diff = numpy.max(diff_mat)
```
#### Notes:
1,The Input data for auto gradient checker should be reasonable to avoid numeric problem.
#### Refs:
- [Gradient checking and advanced optimization(en)](http://deeplearning.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization)
- [Gradient checking and advanced optimization(cn)](http://ufldl.stanford.edu/wiki/index.php/%E6%A2%AF%E5%BA%A6%E6%A3%80%E9%AA%8C%E4%B8%8E%E9%AB%98%E7%BA%A7%E4%BC%98%E5%8C%96)
# Alalysis of large model distributed training in Paddle
***NOTE: This is only some note for how we implemeted this scheme in V1, not a new design.***
## What is it
We often encounter cases that the embedding layer parameters(sparse) are so large that we can not store it in the trainer's memory when training. So we need to put them to several servers, and fetch them row by row instead of fetch all of the parameters.
## How to use
Specify command-line argument like `--loadsave_parameters_in_pserver=true --ports_num_for_sparse=1 --use_old_updater=1` when starting the paddle trainer. And also add something like `--ports_num_for_sparse=1 --pserver_num_threads=5` when starting pserver processes.
Accrodingly, configure your embedding layers like:
```python
SPARSE_REMOTE=True
w1 = data_layer(name="w1", size=dict_size)
emb1 = embedding_layer(input=w1, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
w2 = data_layer(name="w2", size=dict_size)
emb2 = embedding_layer(input=w2, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
...
```
## Implementation details
```c++
enum MatType {
MAT_NORMAL,
MAT_NORMAL_SHARED,
MAT_VALUE_SHARED,
MAT_SPARSE_ROW_IDS,
MAT_SPARSE_ROW_AUTO_GROW,
MAT_CACHE_ROW,
MAT_SPARSE_ROW,
MAT_SPARSE_ROW_PREFETCH,
MAT_SPARSE_ROW_PREFETCH_FULL_SIZE,
};
```
`MAT_SPARSE_ROW_PREFETCH` is what we use when configured to fetch only row of matrix when training.
In `trainer_internal.cpp:L93 trainOneBatch`:
```c++
if (config_->getOptConfig().use_sparse_remote_updater()) {
REGISTER_TIMER("prefetch");
gradientMachine_->prefetch(inArgs);
parameterUpdater_->getParametersRemote();
}
```
When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver.
In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
```c++
if (fullSize) {
...
} else {
getParams = [&] {
parameterClient_->getParameterSparse(
/* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType);
};
applyL1 = [](Parameter& para, real decayRate) {
para.getMat(PARAMETER_VALUE)->applyL1(/*lr=*/1.0f, decayRate);
};
}
```
Calling `parameterClient_->getParameterSparse` will do remote call to pserver's `getParameterSparse`:
```c++
void ParameterServer2::getParameterSparse(const SendParameterRequest& request,
std::vector<Buffer>& inputBuffers,
SendParameterResponse* response,
std::vector<Buffer>* outputBuffers) {
(void)inputBuffers;
auto& buffer = *readWriteBuffer_;
size_t numReals = 0;
for (const auto& block : request.blocks()) {
numReals += getParameterConfig(block).dims(1);
}
buffer.resize(numReals);
VLOG(3) << "pserver: getParameterSparse, numReals=" << numReals;
ReadLockGuard guard(parameterMutex_);
size_t offset = 0;
for (const auto& block : request.blocks()) {
size_t width = getParameterConfig(block).dims(1);
Buffer buf = {buffer.data() + offset, width};
int type = request.send_back_parameter_type();
sendBackParameterSparse(block, type, response, &buf, width, outputBuffers);
offset += width;
}
}
```
`getParameterConfig(block).dims(1)` returns the width of the current "parameter block"(a shard of parameter object),
then `getParameterSparse` remote call returns only one row of data to the client.
......@@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
etcd, trainer ID is a randomly generated UUID, the trainer will
contact the master server requesting to save the model, and find out
if itself is elected. When the master server is not used, unique
trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path
......
# Intel® MKL-DNN on PaddlePaddle: Design Doc
我们计划将Intel深度神经网络数学库(**MKL-DNN**\[[1](#references)\])集成到PaddlePaddle,充分展现英特尔平台的优势,有效提升PaddlePaddle在英特尔架构上的性能。
我们短期内的基本目标是:
- 完成常用layer的MKL-DNN实现。
- 完成常见深度神经网络VGG,GoogLeNet 和 ResNet的MKL-DNN实现。
## Contents
- [Overview](#overview)
- [Actions](#actions)
- [CMake](#cmake)
- [Layers](#layers)
- [Activations](#activations)
- [Unit Tests](#unit-tests)
- [Protobuf Messages](#protobuf-messages)
- [Python API](#python-api)
- [Demos](#demos)
- [Benchmarking](#benchmarking)
- [Others](#others)
- [Design Concerns](#design-concerns)
## Overview
我们会把MKL-DNN作为第三方库集成进PaddlePaddle,整体框架图
<div align="center">
<img src="image/overview.png" width=350><br/>
Figure 1. PaddlePaddle on IA.
</div>
## Actions
我们把集成方案大致分为了如下几个方面。
### CMake
我们会在`CMakeLists.txt`中会添加`WITH_MKLDNN`的选项,当设置这个值为`ON`的时候会启用编译MKL-DNN功能。同时会自动开启OpenMP用于提高MKL-DNN的性能。
同时,我们会引入`WITH_MKLML`选项,用于选择是否使用MKL-DNN自带的MKLML安装包。这个安装包可以独立于MKL-DNN使用,但是建议在开启MKL-DNN的同时也打开MKLML的开关,这样才能发挥最好的性能。
所以,我们会在`cmake/external`目录新建`mkldnn.cmake``mklml.cmake`文件,它们会在编译PaddlePaddle的时候下载对应的软件包,并放到PaddlePaddle的third party目录中。
**备注**:当`WITH_MKLML=ON`的时候,会优先使用这个包作为PaddlePaddle的CBLAS和LAPACK库,所以会稍微改动`cmake/cblas.cmake`中的逻辑。
### Layers
所有MKL-DNN相关的C++ layers,都会按照PaddlePaddle的目录结构存放在
`paddle/gserver/layers`中,并且文件名都会一以*Mkldnn*开头。
所有MKL-DNN的layers都会继承于一个叫做`MkldnnLayer`的父类,该父类继承于PaddlePaddle的基类`Layer`
### Activations
由于在PaddlePaddle中,激活函数是独立于layer概念的,所以会在`paddle/gserver/activations`目录下添加一个`MkldnnActivation.h`文件定义一些用于MKL-DNN的接口,实现方法还是会在`ActivationFunction.cpp`文件。
### Unit Tests
会在`paddle/gserver/test`目录下添加`test_Mkldnn.cpp``MkldnnTester.*`用于MKL-DNN的测试。
Activation的测试,计划在PaddlePaddle原有的测试文件上直接添加新的测试type。
### Protobuf Messages
根据具体layer的需求可能会在`proto/ModelConfig.proto`里面添加必要的选项。
### Python API
目前只考虑**v1 API**
计划在`python/paddle/trainer/config_parser.py`里面添加`use_mkldnn`这个选择,方便用户选择使用MKL-DNN的layers。
具体实现方式比如:
```python
use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
if use_mkldnn
self.layer_type = mkldnn_*
```
所有MKL-DNN的layer type会以*mkldnn_*开头,以示区分。
并且可能在`python/paddle/trainer_config_helper`目录下的`activations.py ``layers.py`里面添加必要的MKL-DNN的接口。
### Demos
会在`v1_api_demo`目录下添加一个`mkldnn`的文件夹,里面放入一些用于MKL-DNN测试的demo脚本。
### Benchmarking
会考虑添加部分逻辑在`benchmark/paddle/image/run.sh`,添加使用MKL-DNN的测试。
### Others
1. 如果在使用MKL-DNN的情况下,会把CPU的Buffer对齐为64。
2. 深入PaddlePaddle,寻找有没有其他可以优化的可能,进一步优化。比如可能会用OpenMP改进SGD的更新性能。
## Design Concerns
为了更好的符合PaddlePaddle的代码风格\[[2](#references)\],同时又尽可能少的牺牲MKL-DNN的性能\[[3](#references)\]
我们总结出一些特别需要注意的点:
1. 使用**deviceId_**。为了尽可能少的在父类Layer中添加变量或者函数,我们决定使用已有的`deviceId_`变量来区分layer的属性,定义`-2``MkldnnLayer`特有的设备ID。
2. 重写父类Layer的**init**函数,修改`deviceId_``-2`,代表这个layer是用于跑在MKL-DNN的环境下。
3. 创建`MkldnnMatrix`,用于管理MKL-DNN会用到的相关memory函数、接口以及会用的到格式信息。
4. 创建`MkldnnBase`,定义一些除了layer和memory相关的类和函数。包括MKL-DNN会用到`MkldnnStream``CpuEngine`,和未来可能还会用到`FPGAEngine`等。
5.**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue``mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。
6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。
7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。
8. 关于MKLDNN参数的保存。由于MKLDNN参数的格式与PaddlePaddle原有的格式存在不一样的情况,所以需要在保存参数时同时保存该格式信息。目前准备扩展[Header](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/Parameter.h#L247)里面的`int32_t version`。这个值不管是在v1还是在v2里面,一直保存的是0,所以可以充分利用这个信息,定义一个枚举处理所有MKLDNN的参数格式,从而`MKLDNNLayer`就可以从输入的参数中获取需要的格式信息。
## References
1. [Intel Math Kernel Library for Deep Neural Networks (Intel MKL-DNN)](https://github.com/01org/mkl-dnn "Intel MKL-DNN")
2. [原来的方案](https://github.com/PaddlePaddle/Paddle/pull/3096)会引入**nextLayer**的信息。但是在PaddlePaddle中,无论是重构前的layer还是重构后的op,都不会想要知道next layer/op的信息。
3. MKL-DNN的高性能格式与PaddlePaddle原有的`NCHW`不同(PaddlePaddle中的CUDNN部分使用的也是`NCHW`,所以不存在这个问题),所以需要引入一个转换方法,并且只需要在必要的时候转换这种格式,才能更好的发挥MKL-DNN的性能。
......@@ -11,6 +11,15 @@ Paddle每次发新的版本,遵循以下流程:
* 编译这个版本的Ubuntu Deb包。如果失败,修复Ubuntu Deb包编译问题,Patch号加一,返回第二步。
* 使用Regression Test List作为检查列表,测试Docker镜像/ubuntu安装包的功能正确性
* 如果失败,记录下所有失败的例子,在这个`release/版本号`分支中,修复所有bug后,Patch号加一,返回第二步
* 编译这个版本的python wheel包,并发布到pypi。
* 由于pypi.python.org目前遵循[严格的命名规范PEP 513](https://www.python.org/dev/peps/pep-0513),在使用twine上传之前,需要重命名wheel包中platform相关的后缀,比如将`linux_x86_64`修改成`manylinux1_x86_64`
* pypi上的package名称为paddlepaddle和paddlepaddle_gpu,如果要上传GPU版本的包,需要修改build/python/setup.py中,name: "paddlepaddle_gpu"并重新打包wheel包:`python setup.py bdist_wheel`
* 上传方法:
```
cd build/python
pip install twine
twine upload dist/[package to upload]
```
4. 第三步完成后,将`release/版本号`分支合入master分支,并删除`release/版本号`分支。将master分支的合入commit打上tag,tag为`版本号`。同时再将`master`分支合入`develop`分支。最后删除`release/版本号`分支。
5. 编译master分支的Docker发行镜像,发布到dockerhub。编译ubuntu的deb包,发布到github release页面
6. 协同完成Release Note的书写
......
......@@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`.
```cpp
class Scope {
public:
Variable* CreateVariable(const std::string& name);
const Variable* GetVariable(const std::string& name) const;
Variable* NewVar(const std::string& name);
const Variable* FindVar(const std::string& name) const;
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
......@@ -58,12 +58,12 @@ class Scope {
public:
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const {
Variable* FindVar(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
return parent_->FindVar(name);
} else {
return nullptr;
}
......@@ -95,10 +95,10 @@ class Scope {
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
// return nullptr if not found.
Variable* GetVariable(const std::string& name) const;
Variable* FindVar(const std::string& name) const;
// return if already contains same name variable.
Variable* CreateVariable(const std::string& name);
Variable* NewVar(const std::string& name);
private:
std::shared_ptr<Scope> parent_;
......@@ -107,11 +107,11 @@ class Scope {
```
## Only scope can create a variable
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`.
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `NewVar` can construct `Variable`.
## When scope destroyed, all variables inside this scope should be destroyed together
The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
The scope hold unique pointers for all variables. User can `FindVar` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
## Sharing a parent scope
......@@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar
## Orthogonal interface
`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily.
`FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `NewVar` will return a `Error` when there is a name conflict locally. Combine `FindVar` and `NewVar`, we can implement `NewVar` easily.
......@@ -49,6 +49,7 @@ message AttrProto {
message VarProto {
required string name = 1;
required string comment = 2;
required bool is_tensor = 3;
};
message OpProto {
......
......@@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习律或者对数据进行归一化处理。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
pip uninstall py_paddle paddle
然后安装paddle的python环境, 在build目录下执行
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
......@@ -68,7 +68,7 @@ As a simple example, consider the following:
1. **BLAS Dependencies(optional)**
CMake will search BLAS libraries from system. If not found, OpenBLAS will be downloaded, built and installed automatically.
CMake will search BLAS libraries from the system. If not found, OpenBLAS will be downloaded, built and installed automatically.
To utilize preinstalled BLAS, you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`.
```bash
......@@ -131,9 +131,9 @@ As a simple example, consider the following:
To build GPU version, you will need the following installed:
1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain
2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn)
4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on
......@@ -172,6 +172,7 @@ export PATH=<path to install>/bin:$PATH
# install PaddlePaddle Python modules.
sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
```
## <span id="centos">Build on Centos 7</span>
### Install Dependencies
......@@ -192,9 +193,9 @@ sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
To build GPU version, you will need the following installed:
1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain
2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn)
4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on
......@@ -222,7 +223,7 @@ mkdir build && cd build
```
Finally, you can build and install PaddlePaddle:
```bash
# you can add build option here, such as:
cmake3 .. -DCMAKE_INSTALL_PREFIX=<path to install>
......
......@@ -3,6 +3,43 @@ PaddlePaddle的Docker容器使用方式
PaddlePaddle目前唯一官方支持的运行的方式是Docker容器。因为Docker能在所有主要操作系统(包括Linux,Mac OS X和Windows)上运行。 请注意,您需要更改 `Dockers设置 <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 才能充分利用Mac OS X和Windows上的硬件资源。
Docker使用入门
------------------------------
几个基础的概念帮助理解和使用Docker:
- *镜像*:一个Docker镜像是一个打包好的软件。它包含了这个软件本身和它所依赖的运行环境。PaddlePaddle的Docker镜像就包含了PaddlePaddle的Python库以及其依赖的多个Python库。这样我们可以直接在Docker中运行需要的程序而不需要安装后在执行。可以执行:
.. code-block:: bash
docker images
来列出当前系统中的所有镜像,同样可以执行:
.. code-block:: bash
docker pull paddlepaddle/paddle:0.10.0
来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。
- *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。
实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。
可以执行:
.. code-block:: bash
docker run paddlepaddle/paddle:0.10.0
来使用一个镜像启动一个容器。
- 默认情况下,Docker容器会运行在独立的文件系统空间之上,我们无法在Docker容器中
访问到主机上的文件。可以通过*挂载Volume*的方式,将主机上的文件或目录挂载到
Docker容器中。下面的命令把当前目录挂载到了容器中的 /data 目录下,容器使用
debian镜像,并且启动后执行 :code:`ls /data`。
.. code-block:: bash
docker run --rm -v $(pwd):/data debian ls /data
PaddlePaddle发布的Docker镜像使用说明
------------------------------
......@@ -12,11 +49,11 @@ PaddlePaddle需要的所有编译工具。把编译出来的PaddlePaddle也打
像,称为生产镜像,里面涵盖了PaddlePaddle运行所需的所有环境。每次
PaddlePaddle发布新版本的时候都会发布对应版本的生产镜像以及开发镜像。运
行镜像包括纯CPU版本和GPU版本以及其对应的非AVX版本。我们会在
`dockerhub.com <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_ 提供最新
的Docker镜像,可以在"tags"标签下找到最新的Paddle镜像版本。为了方便在国
内的开发者下载Docker镜像,我们提供了国内的镜像服务器供大家使用。如果您
在国内,请把文档里命令中的paddlepaddle/paddle替换成
docker.paddlepaddle.org/paddle。
`dockerhub.com <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_
和国内镜像`docker.paddlepaddle.org` 提供最新
的Docker镜像,可以在"tags"标签下找到最新的Paddle镜像版本。
**注意:为了方便在国内的开发者下载Docker镜像,我们提供了国内的镜像服务器供大家使用。如果您在国内,请把文档里命令中的paddlepaddle/paddle替换成docker.paddlepaddle.org/paddle。**
1. 开发镜像::code:`paddlepaddle/paddle:0.10.0-dev`
......@@ -37,13 +74,13 @@ docker.paddlepaddle.org/paddle。
.. code-block:: bash
docker run -it --rm paddlepaddle/paddle:0.10.0-dev /bin/bash
docker run -it --rm -v $(pwd):/paddle paddlepaddle/paddle:0.10.0-dev /bin/bash
或者,可以以后台进程方式运行容器:
.. code-block:: bash
docker run -d -p 2202:22 -p 8888:8888 paddledev/paddle:0.10.0-dev
docker run -d -p 2202:22 -p 8888:8888 -v $(pwd):/paddle paddlepaddle/paddle:0.10.0-dev /usr/sbin/sshd -D
然后用密码 :code:`root` SSH进入容器:
......@@ -68,6 +105,8 @@ docker.paddlepaddle.org/paddle。
如果输出是No,就需要选择使用no-AVX的镜像
**注:在0.10.0之后的版本,PaddlePaddle都可以自动判断硬件是否支持AVX,所以无需判断AVX即可使用**
以上方法在GPU镜像里也能用,只是请不要忘记提前在物理机上安装GPU最新驱动。
为了保证GPU驱动能够在镜像里面正常运行,我们推荐使用[nvidia-docker](https://github.com/NVIDIA/nvidia-docker)来运行镜像。
......
......@@ -63,12 +63,35 @@ CPU-only version and a CUDA GPU version and their no-AVX versions.
We put the docker images on `dockerhub.com
<https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_. You can find the
latest versions under "tags" tab at dockerhub.com. If you are in
China, you can use our Docker image registry mirror to speed up the
download process. To use it, please replace all paddlepaddle/paddle in
the commands to docker.paddlepaddle.org/paddle.
latest versions under "tags" tab at dockerhub.com.
1. Production images, this image might have multiple variants:
** NOTE: If you are in China, you can use our Docker image registry mirror to speed up the download process. To use it, please replace all paddlepaddle/paddle in the commands to docker.paddlepaddle.org/paddle.**
1. development image :code:`paddlepaddle/paddle:<version>-dev`
This image has packed related develop tools and runtime
environment. Users and developers can use this image instead of
their own local computer to accomplish development, build,
releasing, document writing etc. While different version of paddle
may depends on different version of libraries and tools, if you
want to setup a local environment, you must pay attention to the
versions. The development image contains:
- gcc/clang
- nvcc
- Python
- sphinx
- woboq
- sshd
Many developers use servers with GPUs, they can use ssh to login to
the server and run :code:`docker exec` to enter the docker
container and start their work. Also they can start a development
docker image with SSHD service, so they can login to the container
and start work.
2. Production images, this image might have multiple variants:
- GPU/AVX::code:`paddlepaddle/paddle:<version>-gpu`
- GPU/no-AVX::code:`paddlepaddle/paddle:<version>-gpu-noavx`
......@@ -84,7 +107,7 @@ the commands to docker.paddlepaddle.org/paddle.
if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi
**NOTE:versions after 0.10.0 will automatically detect system AVX support, so manual detect is not needed in this case.**
To run the CPU-only image as an interactive container:
.. code-block:: bash
......@@ -103,29 +126,6 @@ the commands to docker.paddlepaddle.org/paddle.
nvidia-docker run -it --rm paddlepaddle/paddle:0.10.0-gpu /bin/bash
2. development image :code:`paddlepaddle/paddle:<version>-dev`
This image has packed related develop tools and runtime
environment. Users and developers can use this image instead of
their own local computer to accomplish development, build,
releasing, document writing etc. While different version of paddle
may depends on different version of libraries and tools, if you
want to setup a local environment, you must pay attention to the
versions. The development image contains:
- gcc/clang
- nvcc
- Python
- sphinx
- woboq
- sshd
Many developers use servers with GPUs, they can use ssh to login to
the server and run :code:`docker exec` to enter the docker
container and start their work. Also they can start a development
docker image with SSHD service, so they can login to the container
and start work.
Train Model Using Python API
----------------------------
......
......@@ -13,22 +13,18 @@
# serve to show the default.
import sys
import os, subprocess
sys.path.insert(0, os.path.abspath('@PADDLE_SOURCE_DIR@/python'))
import shlex
from recommonmark import parser, transform
try:
import py_paddle
import paddle
import paddle.v2
except ImportError:
print("Must install paddle python package before generating documentation")
sys.exit(1)
import paddle
import paddle.v2
MarkdownParser = parser.CommonMarkParser
AutoStructify = transform.AutoStructify
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
templates_path = ["@PROJ_ROOT@/doc_theme/templates"]
templates_path = ["@PADDLE_SOURCE_DIR@/doc_theme/templates"]
# -- General configuration ------------------------------------------------
......@@ -124,7 +120,7 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['@PROJ_ROOT@/doc_theme/static']
html_static_path = ['@PADDLE_SOURCE_DIR@/doc_theme/static']
# Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc'
......
......@@ -13,15 +13,11 @@
# serve to show the default.
import sys
import os, subprocess
sys.path.insert(0, os.path.abspath('@PADDLE_SOURCE_DIR@/python'))
import shlex
from recommonmark import parser, transform
try:
import py_paddle
import paddle
import paddle.v2
except ImportError:
print("Must install paddle python package before generating documentation")
sys.exit(1)
import paddle
import paddle.v2
MarkdownParser = parser.CommonMarkParser
......@@ -29,7 +25,7 @@ AutoStructify = transform.AutoStructify
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
templates_path = ["@PROJ_ROOT@/doc_theme/templates"]
templates_path = ["@PADDLE_SOURCE_DIR@/doc_theme/templates"]
# -- General configuration ------------------------------------------------
......@@ -124,7 +120,7 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['@PROJ_ROOT@/doc_theme/static']
html_static_path = ['@PADDLE_SOURCE_DIR@/doc_theme/static']
# Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc'
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
......@@ -5,12 +19,15 @@ import (
"net"
"net/http"
"net/rpc"
"os"
"os/signal"
"strconv"
"strings"
"time"
"github.com/namsral/flag"
log "github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
......@@ -20,11 +37,18 @@ func main() {
port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
level, e := log.ParseLevel(*logLevel)
candy.Must(e)
log.SetLevel(level)
if *endpoints == "" {
log.Warningln("-endpoints not set, fault tolerance not be enabled.")
}
......@@ -46,6 +70,20 @@ func main() {
store = &master.InMemStore{}
}
shutdown := func() {
log.Infoln("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Errorln(err)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil {
log.Fatal(err)
......@@ -62,8 +100,12 @@ func main() {
log.Fatal(err)
}
err = http.Serve(l, nil)
if err != nil {
log.Fatal(err)
}
go func() {
err = http.Serve(l, nil)
if err != nil {
log.Fatal(err)
}
}()
<-c
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"net"
"net/http"
"net/rpc"
"os"
"os/signal"
"strconv"
"time"
......@@ -16,10 +32,11 @@ import (
func main() {
port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
......@@ -39,16 +56,34 @@ func main() {
if *index >= 0 {
idx = *index
} else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
idx, err = e.Register()
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
idx, err = e.Register(*port)
candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
log.Errorf("Fetch checkpoint failed, %s", err)
if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.")
} else {
panic(err)
}
}
}
shutdown := func() {
log.Infoln("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
log.Errorln(sErr)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err)
......@@ -59,7 +94,11 @@ func main() {
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
candy.Must(err)
log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil)
candy.Must(err)
go func() {
log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil)
candy.Must(err)
}()
<-c
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection
import (
......
hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-07-11T10:04:40.786745417+08:00
hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-08-07T23:37:48.867469328Z
imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd
version: cb2a496c4ddd1c87a9f280e116649b599999ec79
version: d0d1a87aa96ae14914751d42264262cb69eda170
subpackages:
- alarm
- auth
- auth/authpb
- client
- clientv3
- clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages:
......@@ -17,14 +108,63 @@ imports:
- proto
- name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129
version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 1351f936d976c60a0a48d728281922cf63eafb8d
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
......@@ -36,11 +176,15 @@ imports:
- lex/httplex
- trace
- name: golang.org/x/sys
version: abf9c25f54453410d0c6668e519582a9e1115027
version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
repo: https://github.com/golang/sys.git
vcs: git
subpackages:
- unix
- name: golang.org/x/text
version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- secure/bidirule
- transform
......@@ -60,4 +204,18 @@ imports:
- stats
- tap
- transport
testImports: []
- name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
......@@ -6,8 +6,21 @@ import:
subpackages:
- clientv3
- clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag
version: ^1.7.4-pre
- package: github.com/sirupsen/logrus
version: ^1.0.0
- package: github.com/topicai/candy
- package: golang.org/x/crypto
repo: https://github.com/golang/crypto.git
vcs: git
- package: golang.org/x/sys
repo: https://github.com/golang/sys.git
vcs: git
- package: golang.org/x/text
repo: https://github.com/golang/text.git
vcs: git
- package: github.com/satori/go.uuid
version: v1.1.0
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(master_test)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
*/
import "C"
......@@ -19,11 +35,9 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client
......@@ -52,32 +66,32 @@ func remove(client C.paddle_master_client) *master.Client {
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
endpoints := strings.Split(p, ",")
c, err := master.NewClient(
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
master.WithBuffer(bufSize),
)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
if err != nil {
panic(err)
}
return add(c)
}
......@@ -86,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
......@@ -104,23 +124,28 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:error
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
*record = (*C.uchar)(nullPtr)
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}
if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nullPtr)
*record = (*C.uchar)(nil)
return 0
}
......@@ -130,6 +155,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return C.int(size)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan record
conn *connection.Conn
ch chan record
bufSize int
}
type record struct {
......@@ -19,33 +36,104 @@ type record struct {
err error
}
// NewClient creates a new Client.
// WithBuffer sets the client to buffer the training record.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addrCh <-chan string, bufSize int) *Client {
func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.bufSize = bufSize
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh)
go c.getRecords()
return c
for _, opt := range opts {
err := opt(c)
if err != nil {
return nil, err
}
}
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil
}
// StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}
func (c *Client) getRecords() {
func (c *Client) getRecords(passID int) {
for {
t, err := c.getTask()
t, err := c.getTask(passID)
if err != nil {
// TODO(helin): wait before move on with next
// getTask call.
log.Errorln(err)
continue
if err.Error() == ErrPassBefore.Error() ||
err.Error() == ErrNoMoreAvailable.Error() ||
err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
}
for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
f, e := os.Open(chunk.Path)
if e != nil {
log.Errorln(e)
continue
}
......@@ -68,7 +156,10 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.Meta.ID)
err = c.taskFinished(t.Meta.ID)
if err != nil {
log.Errorln(err)
}
}
}
......@@ -98,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
}
}
// SetDataset set dataset for the master server to dispatch.
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
}
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
func (c *Client) getTask(passID int) (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", 0, &t)
err := c.conn.Call("Service.GetTask", passID, &t)
return t, err
}
......@@ -131,3 +225,11 @@ func (c *Client) NextRecord() ([]byte, error) {
r := <-c.ch
return r.r, r.err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......@@ -40,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}
server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)
......@@ -66,11 +80,21 @@ func TestGetFinishTask(t *testing.T) {
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
w.Write(nil)
_, err = w.Write(nil)
if err != nil {
panic(err)
}
// call Close to force RecordIO writing a chunk.
w.Close()
err = w.Close()
if err != nil {
panic(err)
}
}
err = f.Close()
if err != nil {
panic(err)
}
f.Close()
// Manually intialize client to avoid calling c.getRecords()
c := &Client{}
......@@ -79,48 +103,56 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
c.SetDataset([]string{path})
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
task, cErr := c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}
_, err = c.getTask()
if err == nil {
// getting task before task finishes should return error
_, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}
err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
_, cErr = c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil {
t.Fatal(cErr)
}
}
}
for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test
import (
......@@ -6,8 +20,10 @@ import (
"net/http"
"net/rpc"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
......@@ -15,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio"
)
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) {
const (
path = "/tmp/master_client_TestFull"
......@@ -31,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil {
panic(err)
}
......@@ -55,32 +83,67 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
w := recordio.NewWriter(f, -1, -1)
w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ {
w.Write([]byte{byte(i)})
_, err = w.Write([]byte{byte(i)})
if err != nil {
panic(err)
}
}
err = w.Close()
if err != nil {
panic(err)
}
w.Close()
f.Close()
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 {
t.Fatal(pass, i, "Length should be 1.", r)
err = f.Close()
if err != nil {
panic(err)
}
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
}
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
if received[r[0]] {
t.Fatal(pass, i, "Received duplicate.", received, r)
received := make(map[byte]bool)
taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
}
received[r[0]] = true
}
}()
}
wg.Wait()
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......@@ -25,15 +39,12 @@ type EtcdClient struct {
statePath string
client *clientv3.Client
lock *concurrency.Mutex
sess *concurrency.Session
}
// NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints)
// TODO(helin): gracefully shutdown etcd store. Becuase etcd
// store holds a etcd lock, even though the lock will expire
// when the lease timeout, we need to implement graceful
// shutdown to release the lock.
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: dialTimeout,
......@@ -53,14 +64,14 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath)
log.Infof("Trying to acquire lock at %s.", lockPath)
err = lock.Lock(context.TODO())
if err != nil {
return nil, err
}
log.Debugf("Successfully acquired lock at %s.", lockPath)
log.Infof("Successfully acquired lock at %s.", lockPath)
put := clientv3.OpPut(addrPath, string(addr))
put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
if err != nil {
return nil, err
......@@ -75,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
statePath: statePath,
client: cli,
lock: lock,
sess: sess,
}
return e, nil
......@@ -143,9 +155,24 @@ func (e *EtcdClient) Load() ([]byte, error) {
return state, nil
}
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Errorln(newErr)
}
}
return err
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
......@@ -159,8 +186,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
// watchKey watches the specify key and send to valChan if there is some event.
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "sync"
// InMemStore is an in memory implementation of Store interface.
//
// It does not tolerate the fault that casues the program to crash.
// It does not tolerate the fault that causes the program to crash.
type InMemStore struct {
mu sync.Mutex
buf []byte
......@@ -26,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
return m.buf, nil
}
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......@@ -5,6 +19,7 @@ import (
"compress/gzip"
"encoding/gob"
"errors"
"math/rand"
"os"
"path/filepath"
"sync"
......@@ -19,10 +34,23 @@ const (
dialTimeout = 5 * time.Second
)
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state.
type Store interface {
Save([]byte) error
Load() ([]byte, error)
Shutdown() error
}
// Chunk is a chunk of data consisted of several data instances.
......@@ -49,11 +77,12 @@ type taskEntry struct {
NumFailure int
}
type taskQueues struct {
type masterState struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []taskEntry
CurPass int
}
// Service is the master server service.
......@@ -61,16 +90,26 @@ type Service struct {
chunksPerTask int
timeoutDur time.Duration
failureMax int
ready chan struct{}
store Store
mu sync.Mutex
initDone bool
taskQueues taskQueues
ready chan struct{}
initDone bool
mu sync.Mutex
// State to be persisted to snapshot.
state masterState
// The trainer that is currently saving model. This state is
// transient, does not need to be persisted to snapshot.
savingTrainer string
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0
// generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 {
chunksPerTask = 1
}
......@@ -80,7 +119,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id
id++
counter++
id = timestamp + randStart + counter
result = append(result, cur)
cur.Task.Chunks = nil
}
......@@ -102,8 +142,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur
s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur
s.failureMax = failureMax
s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry)
s.state = masterState{}
s.state.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{})
s.store = store
recovered, err := s.recover()
......@@ -141,7 +181,7 @@ func (s *Service) recover() (bool, error) {
}
dec := gob.NewDecoder(gr)
var tqs taskQueues
var tqs masterState
err = dec.Decode(&tqs)
if err != nil {
return false, err
......@@ -154,13 +194,18 @@ func (s *Service) recover() (bool, error) {
log.Errorln(err)
}
s.taskQueues = tqs
s.state = tqs
log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.")
for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
}
return true, nil
}
// snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error {
// TOOD(helin): etcd request has a size limit, so the snapshot
// TODO(helin): etcd request has a size limit, so the snapshot
// size is limited by the max request size. We should either
// divide the snapshot into smaller chunks and save under
// different keys, or configure the request size to be big
......@@ -169,7 +214,7 @@ func (s *Service) snapshot() error {
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
enc := gob.NewEncoder(gw)
err := enc.Encode(s.taskQueues)
err := enc.Encode(s.state)
if err != nil {
return err
}
......@@ -215,6 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
}
count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count)
for i := 0; i < count; i++ {
chunk := Chunk{
Path: path,
......@@ -231,7 +277,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
......@@ -250,19 +296,20 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return err
}
s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
s.state.Todo = partition(chunks, s.chunksPerTask)
err = s.snapshot()
if err != nil {
log.Errorln(err)
return err
}
close(s.ready)
s.initDone = true
return nil
}
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
......@@ -277,17 +324,17 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
}
}()
delete(s.taskQueues.Pending, t.Task.Meta.ID)
delete(s.state.Pending, t.Task.Meta.ID)
t.NumFailure++
if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t)
s.state.Failed = append(s.state.Failed, t)
return
}
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.state.Todo = append(s.state.Todo, t)
return
}
......@@ -296,7 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
t, ok := s.state.Pending[taskID]
if !ok {
return
}
......@@ -308,51 +355,45 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
// must be called with lock held.
func (s *Service) logFields() log.Fields {
return log.Fields{
"todoLen": len(s.taskQueues.Todo),
"pendingLen": len(s.taskQueues.Pending),
"doneLen": len(s.taskQueues.Done),
"failedLen": len(s.taskQueues.Failed),
"todoLen": len(s.state.Todo),
"pendingLen": len(s.state.Pending),
"doneLen": len(s.state.Done),
"failedLen": len(s.state.Failed),
"curPass": s.state.CurPass,
}
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
if passID < s.state.CurPass {
return ErrPassBefore
}
if passID > s.state.CurPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
err := errors.New("all task failed")
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
if len(s.state.Todo) == 0 {
if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
return ErrAllTaskFailed
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
log.WithFields(s.logFields()).Warningln("No more available task.")
return ErrNoMoreAvailable
}
t := s.taskQueues.Todo[0]
t := s.state.Todo[0]
t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.Meta.ID] = t
s.state.Todo = s.state.Todo[1:]
s.state.Pending[t.Task.Meta.ID] = t
err := s.snapshot()
if err != nil {
return err
......@@ -374,7 +415,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
t, ok := s.state.Pending[taskID]
if !ok {
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return nil
......@@ -382,15 +423,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
// task finished, reset timeout
t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)
s.state.Done = append(s.state.Done, t)
delete(s.state.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished
s.state.CurPass++
s.state.Todo = append(s.state.Done, s.state.Failed...)
s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass)
}
err := s.snapshot()
......@@ -409,7 +453,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID]
t, ok := s.state.Pending[meta.ID]
if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil
......@@ -418,3 +462,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.processFailedTask(t, meta.Epoch)
return nil
}
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "testing"
......@@ -30,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.Meta.ID != i {
// test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i)
}
}
......
package master_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m)
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
......@@ -34,7 +48,6 @@ import (
log "github.com/sirupsen/logrus"
)
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client
......@@ -42,10 +55,10 @@ var curHandle C.paddle_pserver_client
func add(c *client.Client) C.paddle_pserver_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
cli := curHandle
curHandle++
handleMap[client] = c
return client
handleMap[cli] = c
return cli
}
func get(client C.paddle_pserver_client) *client.Client {
......@@ -63,7 +76,7 @@ func remove(client C.paddle_pserver_client) *client.Client {
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
if p == nil {
return nil
}
......@@ -77,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type selector bool
func (s selector) Select() bool {
return bool(s)
func (s selector) Select() (bool, error) {
return bool(s), nil
}
func (s selector) Done() error {
return nil
}
type lister []client.Server
......@@ -101,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcd_endpoints)
etcd_client := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
addr := C.GoString(etcdEndpoints)
etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
return add(c)
}
......@@ -114,30 +130,41 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client)
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
if selected := c.BeginInitParams(); selected {
selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1
}
return C.PSERVER_OK
return 0
}
//export paddle_init_param
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int {
et := pserver.ElementType(param.element_type)
name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(param_config, int(config_len)),
Config: cArrayToSlice(paramConfig, int(configLen)),
}
c := get(client)
err := c.InitParam(pc)
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name)
log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name)
return C.PSERVER_OK
}
log.Errorln(err)
......@@ -153,7 +180,7 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams()
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Warningln("parameters already initialized, treat paddle_finish_init_params as sucessful.")
log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK
}
......@@ -223,12 +250,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nullPtr {
if unsafe.Pointer(param) == nil {
log.Errorln("must pre-allocate parameter.")
return C.PSERVER_ERROR
}
if unsafe.Pointer(param.content) != nullPtr {
if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return C.PSERVER_ERROR
......@@ -243,17 +270,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return C.PSERVER_OK
}
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
add_style_check_target(test_cclient test_cclient.c)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <stdlib.h>
......@@ -97,9 +111,5 @@ retry:
getParams(c);
}
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0;
}
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoints = "http://" + etcd_ip + ":2379"
print "etcd endpoints: ", etcd_endpoints
def main():
......@@ -20,19 +28,20 @@ def main():
parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
is_local=False,
pserver_spec="localhost:3000")
pserver_spec=etcd_endpoints,
use_etcd=True)
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
......@@ -47,10 +56,14 @@ def main():
print "Test %d, %.2f" % (event.pass_id, result.cost)
# training
# NOTE: use uci_housing.train() as reader for non-paddlecloud training
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
uci_housing.train(), buf_size=500),
cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2),
feeding={'x': 0,
'y': 1},
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
......@@ -13,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers.
// Selector selects if the client should initialize parameters and
// reports the initialization process done.
type Selector interface {
Select() bool
// Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
}
// Server is the identification of a parameter Server.
......@@ -101,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func (c *Client) BeginInitParams() bool {
func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select()
}
......@@ -205,35 +223,9 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return ps, nil
}
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
_, _ = h.Write([]byte(s))
return h.Sum32()
}
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client_test
import (
"context"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"sync"
"testing"
"time"
......@@ -43,7 +59,7 @@ func initClient() [numPserver]int {
go func(l net.Listener) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
panic(err)
}
......@@ -77,21 +93,43 @@ func initEtcdClient() {
log.Errorf("err %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath)
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
_, err = client.Delete(ctx, pserver.PsDesired)
if err != nil {
panic(err)
}
_, err = client.Delete(ctx, pserver.PsPath)
if err != nil {
panic(err)
}
_, err = client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
if err != nil {
panic(err)
}
ports := initClient()
for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
_, err = client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
if err != nil {
panic(err)
}
}
cancel()
client.Close()
err = client.Close()
if err != nil {
panic(err)
}
}
type selector bool
func (s selector) Select() bool {
return bool(s)
func (s selector) Select() (bool, error) {
return bool(s), nil
}
func (s selector) Done() error {
return nil
}
type lister []client.Server
......@@ -100,27 +138,38 @@ func (l lister) List() []client.Server {
return l
}
func ClientTest(t *testing.T, c *client.Client) {
selected := c.BeginInitParams()
func testClient(t *testing.T, c *client.Client) {
selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
const numParameter = 1000
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
if err != nil {
t.Fatal(err)
}
wg.Add(1)
go func(i int) {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
if err != nil {
t.Fatal(err)
}
wg.Done()
}(i)
}
wg.Wait()
err = c.FinishInitParams()
if err != nil {
......@@ -128,7 +177,7 @@ func ClientTest(t *testing.T, c *client.Client) {
}
var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ {
for i := 0; i < numParameter; i++ {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
......@@ -136,9 +185,31 @@ func ClientTest(t *testing.T, c *client.Client) {
grads = append(grads, g)
}
err = c.SendGrads(grads)
if err != nil {
t.Fatal(err)
const paramPerGroup = 10
const numGroups = numParameter / paramPerGroup
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err := c.SendGrads(gs)
if err != nil {
t.Fatal(err)
}
wg.Done()
}(gs)
}
names := make([]string, numParameter)
......@@ -146,20 +217,35 @@ func ClientTest(t *testing.T, c *client.Client) {
names[i] = "p_" + strconv.Itoa(i)
}
params, err := c.GetParams(names)
if err != nil {
t.Fatal(err)
}
for i := 0; i < numGroups; i++ {
var ns []string
if i == numGroups-1 {
ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
if len(names) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
}
wg.Add(1)
go func(ns []string) {
params, err := c.GetParams(ns)
if err != nil {
t.Fatal(err)
}
for i := range params {
if names[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
}
if len(ns) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
}
for i := range params {
if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
}
}
wg.Done()
}(ns)
}
wg.Wait()
}
func TestNativeClient(t *testing.T) {
......@@ -169,13 +255,14 @@ func TestNativeClient(t *testing.T) {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1)
testClient(t, c1)
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
// EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient(t *testing.T) {
initEtcdClient()
etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2)
testClient(t, c2)
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
const (
DefaultEtcdTimeout time.Duration = 5 * time.Second
defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
)
// EtcdClient is used by pserver client that is a part of trainer process.
// Etcd is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type EtcdClient struct {
// 1. add watcher to watch the change state of pservers.
type Etcd struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
lock *concurrency.Mutex
}
// Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int {
func (e *Etcd) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err)
time.Sleep(p.timeout)
log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(e.timeout)
continue
}
......@@ -59,26 +80,26 @@ func (p *EtcdClient) Desired() int {
}
// List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server {
psDesired := p.Desired()
func (e *Etcd) List() []Server {
psDesired := e.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
resp, err := e.client.Get(ctx, psKey)
cancel()
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
......@@ -86,10 +107,10 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout)
time.Sleep(e.timeout)
continue
}
log.Infof("got value (%s) for key: %s", psAddr, psKey)
log.Debugf("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i
servers[i].Addr = psAddr
}
......@@ -99,27 +120,135 @@ func (p *EtcdClient) List() []Server {
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient {
func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
for {
cli, err = clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: DefaultEtcdTimeout,
DialTimeout: defaultEtcdTimeout,
})
if err != nil {
log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout)
time.Sleep(defaultEtcdTimeout)
continue
}
break
}
log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{
client := &Etcd{
client: cli,
timeout: DefaultEtcdTimeout,
timeout: defaultEtcdTimeout,
endpoints: ep,
}
return client
}
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath)
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Infof("Successfully acquired lock at %s.", initLockPath)
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Infoln("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Errorln(cErr)
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate trainer crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
......@@ -20,16 +34,19 @@ const (
PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
retryTimeout = 5 * time.Second
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
numPservers int
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
numPservers int
endpoints string
client *clientv3.Client
sess *concurrency.Session
dialTimeout time.Duration
ttlSec int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
// desired number of pservers in the job.
......@@ -38,19 +55,19 @@ type EtcdClient struct {
}
// NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
return &EtcdClient{
etcdTimeout: timeout,
numPservers: numPservers,
etcdEndpoints: endpoints,
dialTimeout: dialtimeout,
ttlSec: ttlSec,
numPservers: numPservers,
endpoints: endpoints,
}
}
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {
func (e *EtcdClient) Register(port int) (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
......@@ -58,19 +75,26 @@ func (e *EtcdClient) Register() (int, error) {
}
// initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",")
ep := strings.Split(e.endpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: e.etcdTimeout,
DialTimeout: e.dialTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout)
time.Sleep(retryTimeout)
continue
}
e.etcdClient = cli
log.Debugf("inited client to %s", e.etcdEndpoints)
e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
log.Errorf("create etcd session error: %v", err)
time.Sleep(retryTimeout)
continue
}
e.sess = sess
log.Debugf("inited client to %s", e.endpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
......@@ -81,7 +105,7 @@ func (e *EtcdClient) Register() (int, error) {
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
time.Sleep(retryTimeout)
continue
}
break
......@@ -92,18 +116,18 @@ func (e *EtcdClient) Register() (int, error) {
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired)
resp, err := e.client.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout)
time.Sleep(retryTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout)
time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change
continue
}
......@@ -116,11 +140,11 @@ func (e *EtcdClient) Register() (int, error) {
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx)
pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
time.Sleep(retryTimeout)
continue
}
break
......@@ -130,19 +154,19 @@ func (e *EtcdClient) Register() (int, error) {
}
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
_, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i)
......@@ -151,35 +175,20 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
log.Debug("register finished")
idx = i
registered = true
break
}
}
if registered == true {
if registered {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
return errors.New("not registered, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil {
......@@ -192,11 +201,12 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key)
resp, err := e.client.Get(ctx, key)
cancel()
if err != nil {
return []byte{}, err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return []byte{}, nil
......@@ -206,12 +216,34 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
}
// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value))
var err error
if withLease {
_, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
} else {
_, err = e.client.Put(ctx, key, string(value))
}
cancel()
if err != nil {
return err
return err
}
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
var err error
if e.sess != nil {
err = e.sess.Close()
}
if e.client != nil {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
log.Errorln(newErr)
} else {
err = newErr
}
}
}
return nil
return err
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
// #cgo CFLAGS: -I ../../
......@@ -14,15 +28,15 @@ import (
log "github.com/sirupsen/logrus"
)
var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct {
opt *C.struct_paddle_optimizer
elementType ElementType
contentLen int
config []byte
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
if p == nil {
return nil
}
......@@ -37,10 +51,11 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
o.contentLen = len(paramWithConfigs.Param.Content)
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
paramBufferSize := C.size_t(len(p.Content) / C.sizeof_float)
paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": paramBufferSize,
......@@ -56,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}
o.config = c
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
return o
......@@ -78,7 +94,11 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float)
if o.contentLen != len(g.Content) {
return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content))
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r)
}
......@@ -86,8 +106,8 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
}
func (o *optimizer) Cleanup() {
if unsafe.Pointer(o.opt) != nullPtr {
if unsafe.Pointer(o.opt) != nil {
C.paddle_release_optimizer(o.opt)
o.opt = (*C.struct_paddle_optimizer)(nullPtr)
o.opt = (*C.struct_paddle_optimizer)(nil)
}
}
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
......
此差异已折叠。
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver_test
import (
......@@ -16,7 +30,7 @@ const (
func TestServiceFull(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
t.Error(err)
}
......@@ -31,7 +45,7 @@ func TestServiceFull(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
var p1 pserver.Parameter
......@@ -40,40 +54,40 @@ func TestServiceFull(t *testing.T) {
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
var param pserver.Parameter
err = s.GetParam("param_b", &param)
if err != nil {
t.FailNow()
t.Fatal(err)
}
if !reflect.DeepEqual(param, p1) {
t.FailNow()
t.Fatal("not equal:", param, p1)
}
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.SendGrad(g2, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
var param1 pserver.Parameter
err = s.GetParam("param_a", &param1)
if err != nil {
t.FailNow()
t.Fatal(err)
}
// don't compare content, since it's already changed by
......@@ -82,39 +96,39 @@ func TestServiceFull(t *testing.T) {
p.Content = nil
if !reflect.DeepEqual(param1, p) {
t.FailNow()
t.Fatal("not equal:", param1, p)
}
}
func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
t.Error(err)
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err.Error() != pserver.AlreadyInitialized {
t.FailNow()
t.Fatal(err)
}
}
func TestUninitialized(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
t.Fatal(err)
}
}
func TestBlockUntilInitialized(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil {
t.Error(err)
}
......@@ -154,12 +168,12 @@ func TestBlockUntilInitialized(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
wg.Wait()
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(network_helper_test)
endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper
import "testing"
......
......@@ -21,22 +21,15 @@
#
# It same as PYTHONPATH=${YOUR_PYTHON_PATH}:$PYTHONPATH {exec...}
#
if ! python -c "import paddle" >/dev/null 2>/dev/null; then
PYPATH=""
set -x
while getopts "d:" opt; do
case $opt in
d)
PYPATH=$OPTARG
;;
esac
done
shift $(($OPTIND - 1))
export PYTHONPATH=$PYPATH:$PYTHONPATH
$@
else
echo "paddle package is already in your PYTHONPATH. But unittest need a clean environment."
echo "Please uninstall paddle package before start unittest. Try to 'pip uninstall paddle'"
exit 1
fi
PYPATH=""
set -x
while getopts "d:" opt; do
case $opt in
d)
PYPATH=$OPTARG
;;
esac
done
shift $(($OPTIND - 1))
export PYTHONPATH=$PYPATH:$PYTHONPATH
$@
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double Evaluator::getValue(const std::string name) const {
paddle::Error err;
double v = m->rawPtr->getValue(name, &err);
if (err) {
if (!err.isOK()) {
throw std::runtime_error(err.msg());
}
return v;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册