提交 fb413508 编写于 作者: L liaogang

Merge conflict

| Github account | name | | Github account | name |
|---|---| |---|---|
| reyoung | Yang Yu | | backyes | Yan-Fei Wang |
| beckett1124 | Bin Qi |
| Canpio | Jia-Yi Feng |
| chengxiaohua1105 | Xiao-Hua Cheng |
| cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang |
| cxysteven | Xing-Yi Cheng |
| dzhwinter | Zhi-Hong Dong |
| emailweixu | Wei Xu |
| gangliao | Gang Liao | | gangliao | Gang Liao |
| luotao01 | Tao Luo | | gongweibao | Wei-Bao Gong |
| jacquesqiao | Long-Fei Qiao | | Guo Sheng | Sheng Guo |
| qingqing01 | Qing-Qing Dang | | Haichao-Zhang | Hai-Chao Zhang |
| hedaoyuan | Dao-Yuan He | | hedaoyuan | Dao-Yuan He |
| wangyang59 | Yang Wang | | helinwang | He-Lin Wang |
| jacquesqiao | Long-Fei Qiao |
| kuke | Yi-Bing Liu |
| lcy-seso | Ying Cao |
| lipeng-unisound | Peng Li |
| liuyuan | Yuan Liu |
| livc | Zhao Li |
| llxxxll | Yong-Feng Liu |
| luotao01 | Tao Luo |
| lzhao4ever | Liang Zhao |
| NHZlX | Zhao-Long Xing |
| pakchoi | Chuan-Jiang Song |
| pengli09 | Peng Li |
| pkuyym | Ya-Ming Yang |
| QiJune | Jun Qi | | QiJune | Jun Qi |
| qingqing01 | Qing-Qing Dang |
| reyoung | Yang Yu |
| Superjom | Chun-Wei Yan |
| tianbingsz | Tian-Bing Xu | | tianbingsz | Tian-Bing Xu |
| cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang |
| typhoonzero | Yi Wu | | typhoonzero | Yi Wu |
| backyes | Yan-Fei Wang | | wanghaoshuang | Hao-Shuang Wang |
| pengli09 | Peng Li | | wangyang59 | Yang Wang |
| livc | Zhao Li | | wangzhen-nlp | Zhen Wang |
| wen-bo-yang | Wen-Bo Yang |
| wwhu | Wei-Wei Hu |
| xinghai-sun | Xing-Hai Sun |
| Xreki | Yi-Qun Liu | | Xreki | Yi-Qun Liu |
| xujun05 | Jun Xu |
| xushaoyong | Shao-Yong Xu |
| Yancey1989 | Xu Yan | | Yancey1989 | Xu Yan |
| emailweixu | Wei Xu | | zhaopu7 | Pu Zhao |
| wen-bo-yang | Wen-Bo Yang |
| helinwang | He-Lin Wang |
| lcy-seso | Ying Cao |
| Zrachel | Rui-Qing Zhang |
| Haichao-Zhang | Hai-Chao Zhang |
| gongweibao | Wei-Bao Gong |
| lzhao4ever | Liang Zhao |
| zhouxiao-coder | Xiao Zhou | | zhouxiao-coder | Xiao Zhou |
| lipeng-unisound | Peng Li | | Zrachel | Rui-Qing Zhang |
...@@ -49,6 +49,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) ...@@ -49,6 +49,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF) option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF)
option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
...@@ -96,6 +97,8 @@ include(external/warpctc) # download, build, install warpctc ...@@ -96,6 +97,8 @@ include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any include(external/any) # download libn::any
include(external/eigen) # download eigen3 include(external/eigen) # download eigen3
include(cudnn) # set cudnn libraries, must before configure
include(configure) # add paddle env configuration
include(generic) # simplify cmake module include(generic) # simplify cmake module
include(package) # set paddle packages include(package) # set paddle packages
include(cpplint) # set paddle c++ style include(cpplint) # set paddle c++ style
...@@ -103,15 +106,14 @@ include(ccache) # set ccache for compilation ...@@ -103,15 +106,14 @@ include(ccache) # set ccache for compilation
include(util) # set unittest and link libs include(util) # set unittest and link libs
include(rdma) # set rdma libraries include(rdma) # set rdma libraries
include(flags) # set paddle compile flags include(flags) # set paddle compile flags
include(cudnn) # set cudnn libraries
include(version) # set PADDLE_VERSION include(version) # set PADDLE_VERSION
include(coveralls) # set code coverage include(coveralls) # set code coverage
include(configure) # add paddle env configuration
include_directories("${PROJ_ROOT}") include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS}) include_directories(${Boost_INCLUDE_DIRS})
set(EXTERNAL_LIBS set(EXTERNAL_LIBS
...@@ -139,8 +141,7 @@ add_subdirectory(proto) ...@@ -139,8 +141,7 @@ add_subdirectory(proto)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it. # placed after this block, because they depends on it.
if(WITH_GOLANG) if(WITH_GOLANG)
add_subdirectory(go/master/c) add_subdirectory(go)
add_subdirectory(go/pserver/cclient)
endif(WITH_GOLANG) endif(WITH_GOLANG)
add_subdirectory(paddle) add_subdirectory(paddle)
......
...@@ -34,14 +34,18 @@ RUN apt-get update && \ ...@@ -34,14 +34,18 @@ RUN apt-get update && \
net-tools && \ net-tools && \
apt-get clean -y apt-get clean -y
# Install Go # Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \ RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \ tar -C /usr/local -xzf go.tgz && \
mkdir /root/gopath && \ mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT. # should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# install glide
RUN curl -q https://glide.sh/get | sh
# git credential to skip password typing # git credential to skip password typing
RUN git config --global credential.helper store RUN git config --global credential.helper store
...@@ -57,7 +61,7 @@ RUN pip install --upgrade pip && \ ...@@ -57,7 +61,7 @@ RUN pip install --upgrade pip && \
pip install -U docopt PyYAML sphinx && \ pip install -U docopt PyYAML sphinx && \
pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \ pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \
pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \ pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install rarfile pip install rarfile
# To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -79,6 +79,9 @@ if(WITH_GOLANG) ...@@ -79,6 +79,9 @@ if(WITH_GOLANG)
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go") set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH}) file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle") set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle")
file(MAKE_DIRECTORY "${PADDLE_IN_GOPATH}")
set(PADDLE_GO_PATH "${CMAKE_SOURCE_DIR}/go")
add_custom_target(go_path) add_custom_target(go_path)
add_custom_command(TARGET go_path add_custom_command(TARGET go_path
# Symlink Paddle directory into GOPATH # Symlink Paddle directory into GOPATH
...@@ -89,7 +92,22 @@ if(WITH_GOLANG) ...@@ -89,7 +92,22 @@ if(WITH_GOLANG)
# We can't run `go get -d ./...` for every target, because # We can't run `go get -d ./...` for every target, because
# multiple `go get` can not run concurrently, but make need to be # multiple `go get` can not run concurrently, but make need to be
# able to run with multiple jobs. # able to run with multiple jobs.
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ./go/...
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
) )
if (GLIDE_INSTALL)
if(EXISTS $ENV{GOPATH}/bin/glide)
set(GLIDE "$ENV{GOPATH}/bin/glide")
else()
message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide")
endif()
add_custom_target(go_vendor)
add_custom_command(TARGET go_vendor
COMMAND env GOPATH=${GOPATH} ${GLIDE} install
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go"
)
add_dependencies(go_vendor go_path)
endif()
endif(WITH_GOLANG) endif(WITH_GOLANG)
...@@ -25,6 +25,7 @@ set(STYLE_FILTER "${STYLE_FILTER}-readability/casting") ...@@ -25,6 +25,7 @@ set(STYLE_FILTER "${STYLE_FILTER}-readability/casting")
set(IGNORE_PATTERN set(IGNORE_PATTERN
.*ImportanceSampler.* .*ImportanceSampler.*
.*cblas\\.h.* .*cblas\\.h.*
.*\\.pb\\.txt
.*LtrDataProvider.* .*LtrDataProvider.*
.*MultiDataProvider.*) .*MultiDataProvider.*)
......
...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject) ...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject)
SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any) SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any)
INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/linb_any) INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/extern_lib_any)
ExternalProject_Add( ExternalProject_Add(
linb_any extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/thelink2012/any.git" GIT_REPOSITORY "https://github.com/thelink2012/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020" GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020"
...@@ -17,5 +17,15 @@ ExternalProject_Add( ...@@ -17,5 +17,15 @@ ExternalProject_Add(
TEST_COMMAND "" TEST_COMMAND ""
) )
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(lib_any STATIC ${dummyfile})
else()
add_library(lib_any INTERFACE)
endif()
add_dependencies(lib_any extern_lib_any)
add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE) add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE)
LIST(APPEND external_project_dependencies linb_any) LIST(APPEND external_project_dependencies lib_any)
\ No newline at end of file
...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject) ...@@ -2,10 +2,10 @@ INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add( ExternalProject_Add(
eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website # for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
...@@ -26,4 +26,14 @@ ExternalProject_Add( ...@@ -26,4 +26,14 @@ ExternalProject_Add(
TEST_COMMAND "" TEST_COMMAND ""
) )
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_eigen3 = \"${dummyfile}\";")
add_library(eigen3 STATIC ${dummyfile})
else()
add_library(eigen3 INTERFACE)
endif()
add_dependencies(eigen3 extern_eigen3)
LIST(APPEND external_project_dependencies eigen3) LIST(APPEND external_project_dependencies eigen3)
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# generic.cmake defines CMakes functions that look like Bazel's # generic.cmake defines CMakes functions that look like Bazel's
# building rules (https://bazel.build/). # building rules (https://bazel.build/).
# #
# #
# ------------------------------------------- # -------------------------------------------
# C++ CUDA C++ Go # C++ CUDA C++ Go
# ------------------------------------------- # -------------------------------------------
...@@ -25,51 +25,51 @@ ...@@ -25,51 +25,51 @@
# cc_binary nv_binary go_binary # cc_binary nv_binary go_binary
# cc_test nv_test go_test # cc_test nv_test go_test
# ------------------------------------------- # -------------------------------------------
# #
# To build a static library example.a from example.cc using the system # To build a static library example.a from example.cc using the system
# compiler (like GCC): # compiler (like GCC):
# #
# cc_library(example SRCS example.cc) # cc_library(example SRCS example.cc)
# #
# To build a static library example.a from multiple source files # To build a static library example.a from multiple source files
# example{1,2,3}.cc: # example{1,2,3}.cc:
# #
# cc_library(example SRCS example1.cc example2.cc example3.cc) # cc_library(example SRCS example1.cc example2.cc example3.cc)
# #
# To build a shared library example.so from example.cc: # To build a shared library example.so from example.cc:
# #
# cc_library(example SHARED SRCS example.cc) # cc_library(example SHARED SRCS example.cc)
# #
# To build a library using Nvidia's NVCC from .cu file(s), use the nv_ # To build a library using Nvidia's NVCC from .cu file(s), use the nv_
# prefixed version: # prefixed version:
# #
# nv_library(example SRCS example.cu) # nv_library(example SRCS example.cu)
# #
# To specify that a library new_example.a depends on other libraies: # To specify that a library new_example.a depends on other libraies:
# #
# cc_library(new_example SRCS new_example.cc DEPS example) # cc_library(new_example SRCS new_example.cc DEPS example)
# #
# Static libraries can be composed of other static libraries: # Static libraries can be composed of other static libraries:
# #
# cc_library(composed DEPS dependent1 dependent2 dependent3) # cc_library(composed DEPS dependent1 dependent2 dependent3)
# #
# To build an executable binary file from some source files and # To build an executable binary file from some source files and
# dependent libraries: # dependent libraries:
# #
# cc_binary(example SRCS main.cc something.cc DEPS example1 example2) # cc_binary(example SRCS main.cc something.cc DEPS example1 example2)
# #
# To build an executable binary file using NVCC, use the nv_ prefixed # To build an executable binary file using NVCC, use the nv_ prefixed
# version: # version:
# #
# nv_binary(example SRCS main.cc something.cu DEPS example1 example2) # nv_binary(example SRCS main.cc something.cu DEPS example1 example2)
# #
# To build a unit test binary, which is an executable binary with # To build a unit test binary, which is an executable binary with
# GoogleTest linked: # GoogleTest linked:
# #
# cc_test(example_test SRCS example_test.cc DEPS example) # cc_test(example_test SRCS example_test.cc DEPS example)
# #
# To build a unit test binary using NVCC, use the nv_ prefixed version: # To build a unit test binary using NVCC, use the nv_ prefixed version:
# #
# nv_test(example_test SRCS example_test.cu DEPS example) # nv_test(example_test SRCS example_test.cu DEPS example)
# #
# It is pretty often that executable and test binaries depend on # It is pretty often that executable and test binaries depend on
...@@ -162,6 +162,7 @@ function(cc_library TARGET_NAME) ...@@ -162,6 +162,7 @@ function(cc_library TARGET_NAME)
endif() endif()
if (cc_library_DEPS) if (cc_library_DEPS)
add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
endif() endif()
else(cc_library_SRCS) else(cc_library_SRCS)
if (cc_library_DEPS) if (cc_library_DEPS)
...@@ -211,6 +212,7 @@ function(nv_library TARGET_NAME) ...@@ -211,6 +212,7 @@ function(nv_library TARGET_NAME)
endif() endif()
if (nv_library_DEPS) if (nv_library_DEPS)
add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
target_link_libraries(${TARGET_NAME} ${nv_library_DEPS})
endif() endif()
else(nv_library_SRCS) else(nv_library_SRCS)
if (nv_library_DEPS) if (nv_library_DEPS)
...@@ -278,14 +280,16 @@ function(go_library TARGET_NAME) ...@@ -278,14 +280,16 @@ function(go_library TARGET_NAME)
set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code # Golang build source code
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${${TARGET_NAME}_LIB_PATH}" -o "${${TARGET_NAME}_LIB_PATH}"
${GO_SOURCE} "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) # must run under GOPATH
add_dependencies(${TARGET_NAME} go_path) WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_dependencies(${TARGET_NAME} go_vendor)
endfunction(go_library) endfunction(go_library)
function(go_binary TARGET_NAME) function(go_binary TARGET_NAME)
...@@ -293,12 +297,15 @@ function(go_binary TARGET_NAME) ...@@ -293,12 +297,15 @@ function(go_binary TARGET_NAME)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_library_SRCS} "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_binary_DEPS}) # TODO: don't know what ${TARGET_NAME}_link does
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary) endfunction(go_binary)
...@@ -318,10 +325,10 @@ endfunction(go_test) ...@@ -318,10 +325,10 @@ endfunction(go_test)
function(proto_library TARGET_NAME) function(proto_library TARGET_NAME)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(proto_srcs) set(proto_srcs)
set(proto_hdrs) set(proto_hdrs)
protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS}) protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS})
cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS protobuf) cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS ${proto_library_DEPS} protobuf)
endfunction() endfunction()
## Interaction between C++ and Python
Users employ API in Python to describe their own network, however, the network construction actually happens in C++. so Protobuf is introduced to send the message between Python and C++.
The Interaction between Python and C++ can be simplified as two steps:
1. C++ tells Python how many Ops there are, and what parameter do users need to offer to initialize a new Op. Python then builds API for each Op at compile time.
2. Users invoke APIs built by Python and provide necessary parameters. These parameters will be sent to C++ fo finish Op construction task.
### Message form C++ to Python
We define a Protobuf message class `OpProto` to hold message needed in the first step. What should an `OpProto` contain? This question is equivalent to “What message do we need to offer, to build a Python API which is legal and user oriented and can use to describe a whole Op.”
Following message are necessary:
1. Op's name, and its simple comment.
2. Input and output variable number; each variable's name, type, and comment.
3. Op's attributes; each attribute includes name, type, comment, **default value** and **value range**.
So `OpProto` can be defined as follows:
```proto
enum AttrType {
INT = 1;
FLOAT = 2;
STRING = 3;
INTS = 4;
FLOATS = 5;
STRINGS = 6;
};
message AttrValue {
AttrType type = 1;
optional int iv = 2;
optional float fv = 3;
optional string sv = 4;
repeated int ivs = 5;
repeated float fvs = 6;
repeated string svs = 7;
};
message AttrProto {
required string name = 1;
required string comment = 2;
required AttrType type = 3;
};
message VarProto {
required string name = 1;
required string comment = 2;
};
message OpProto {
repeated VarProto inputs = 1;
repeated VarProto outputs = 2;
repeated AttrProto attrs = 3;
required string type = 4;
required string comment = 5;
};
```
To generate Python code automatically:
```python
def create_python_ops_creatation_functions():
op_protos = paddle.framework.OpRegistry.get_all_op_proto()
for type_name in op_protos:
op_proto = op_protos[type_name]
def __impl__(**kwargs): # User must use key word args in Paddle API
inputs = [kwargs.get(ipt.name, "") for ipt in op_proto.inputs]
outputs = [kwargs.get(opt.name, "") for opt in op_proto.outputs]
attrs = [cast_to_op_attr(attr, kwargs.get(attr.name, None)) for attr in op_proto.attrs]
opdesc = input, outputs, type_name, attrs
return paddle.framework.OpRegistry.CreateOp(opdesc)
__impl__.__doc__ = create_doc_string(op_proto)
globals()[type_name] = __impl__
create_python_ops_creatation_functions()
```
### Message from Python to C++
To hold message needed in the above second step, we define Protobuf message class `OpDesc`. It is used to hold user-specified parameters in Op describing.
```proto
message OpDesc {
required string type = 1;
repeated string inputs = 2;
repeated string outputs = 3;
map<string, AttrValue> attrs = 4;
};
```
## OpProto Register
Every Op has its own `OpProto`. For using convenience, we need to register them and record all their messages. For each `Op` class, we define a corresponding `OpMaker` class, in whose constructor we implement the `OpProto`'s building process. `OpMaker`'s constructor will be invoked by another function `OpRegistry::RegisterOp()`.
```cpp
class OpProtoMaker {
public:
OpProtoMaker(OpProto* proto): proto_(proto) {}
protected:
OpProto* proto_;
void AddInput(const std::string& name, const std::string& desc) {...}
void AddAttr(const std::string& name, const std::string& desc, TypeId type) {...}
void AddComment(const std::string& comment) { ... }
};
class OpRegistry {
public:
using OpCreator = std::function<OperatorBase* (OpDesc& desc)>;
template <typename OpType, typename OpMaker>
static void RegisterOp(const std::string& name) {
gCreators_[name] = [](const OpDesc& desc) {
return new OpType(desc);
};
OpProto& opProto = gProtos_[name];
OpMaker()(&opProto);
}
static map<string, OpCreator> gCreators_;
static map<string, OpProto> gProtos_;
};
template <typename OpType, typename OpMaker>
class OpRegister {
public:
OpRegister(std::string type) {
OpRegistry::RegisterOp<OpType, OpMaker>(type);
}
};
#define REGISTER_OP(op_class, op_maker_class, type_name) \
class op_class##Register { \
private: \
const static OpRegister<#op_class, #op_maker_class> reg; \
}; \
const Register op_class##Register::reg(#type_name);
class CosineOp {
// ...
}
struct CosineOpProtoMaker : public OpProtoMaker {
CosineOpProtoMaker(OpProto* proto) : OpProtoMaker(proto) {
AddInput("input", "input of cosine op");
AddAttr("scale", "scale of cosine op", float).Default(1.0).LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
}
REGISTER_OP(CosineOp, CosineOpProtoMaker, cos);
```
In `REGISTER_OP(CosineOp, CosineOpProtoMaker, cos)`, we register not only `CosineOp` but also `CosineOpProto`. As fields of `CosineOpProto`, the default value and value range of `scale` are also registered here.
## Python API
Python APIs are divided into two types, high-level API and low-level API.
### High-Level API
High-level API is called by users directly, so it should keep its style consistent with existing V2 APIs.
Here is a sample about how a define a fc layer:
```python
hd = fc_layer(input=data, size=56, with_bias=True, activation="sigmoid");
```
`hd` is the output of `fc_layer` and it's a `variable`. It can be further sent into other layers as input.
The definition of `fc_layer()`:
```python
def fc_layer(input, size, with_bias, activation):
attr_map = {"size":size}
check_attrs(attr_map)
w = make_variable('w')
if with_bias:
b = make_variable('b')
else:
b = None
fc_output = make_variable('fc_output');
fc_op(input, w, b, fc_output, attr_map)
act_output = make_variable('sigmod_output');
if activation == "sigmod":
sigmod_op(fc_output, act_output);
elif:
# ...
return act_output;
```
### Low Leval API
In above sample, `fc_op` and `sigmod_op` are low-level API. They build `OpDesc` and invoke corresponding C++ code.
*TODO*
vendor/
.glide/
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(master SRC master.go)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(pserver SRCS pserver.go)
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
...@@ -29,11 +30,16 @@ func main() { ...@@ -29,11 +30,16 @@ func main() {
} }
log.SetLevel(level) log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout)) var idx int
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) if *index >= 0 {
idx, err := e.Register() idx = *index
if err != nil { } else {
panic(err) timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
} }
s, err := pserver.NewService(idx) s, err := pserver.NewService(idx)
......
hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7
updated: 2017-06-27T14:05:48.925262819+08:00
imports:
- name: github.com/coreos/etcd
version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7
subpackages:
- auth/authpb
- clientv3
- clientv3/concurrency
- etcdserver/api/v3rpc/rpctypes
- etcdserver/etcdserverpb
- mvcc/mvccpb
- name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages:
- jsonpb
- proto
- name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
- context
- http2
- http2/hpack
- idna
- internal/timeseries
- lex/httplex
- trace
- name: golang.org/x/sys
version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733
subpackages:
- unix
- name: golang.org/x/text
version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77
subpackages:
- secure/bidirule
- transform
- unicode/bidi
- unicode/norm
- name: google.golang.org/grpc
version: 8050b9cbc271307e5a716a9d782803d09b0d6f2d
subpackages:
- codes
- credentials
- grpclog
- internal
- keepalive
- metadata
- naming
- peer
- stats
- tap
- transport
testImports: []
package: github.com/PaddlePaddle/Paddle/go
import:
- package: github.com/PaddlePaddle/recordio
- package: github.com/coreos/etcd
version: ^3.2.1
subpackages:
- clientv3
- clientv3/concurrency
- package: github.com/namsral/flag
version: ^1.7.4-pre
- package: github.com/sirupsen/logrus
version: ^1.0.0
cmake_minimum_required(VERSION 3.0)
go_library(paddle_master SHARED) go_library(paddle_master SHARED)
...@@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
lock := concurrency.NewMutex(sess, lockPath) lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have // It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have // multiple master servers running (only configured to have
// one master running, but split-brain problem may cuase // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath) log.Debugf("Trying to acquire lock at %s.", lockPath)
...@@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error { ...@@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error {
// We lost the master lock and can not acquire // We lost the master lock and can not acquire
// it back, it means some other master is // it back, it means some other master is
// already started. We don't want cluster // already started. We don't want cluster
// managment system to kill the master server // management system to kill the master server
// who is holding the lock and running // who is holding the lock and running
// correctly. So the most feasible solution is // correctly. So the most feasible solution is
// to kill current master server. The current // to kill current master server. The current
......
go_library(paddle_pserver_cclient STATIC)
add_subdirectory(test)
#include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h"
typedef float real;
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
printf("test failed.\n");
exit(-1);
}
void print_parameter(paddle_gradient* param) {
if (param == NULL) {
printf("param is NULL!!\n");
} else {
printf("==== parameter ====\n");
printf("name: %s\n", param->name);
printf("content_len: %d\n", param->content_len);
printf("content_type: %d\n", param->element_type);
int i;
for (i = 0; i < param->content_len / (int)sizeof(real); ++i) {
printf("%f ", ((float*)param->content)[i]);
}
printf("\n\n");
}
}
int main() {
char addr[] = "localhost:3000";
paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
char* names[] = {"param_a", "param_b"};
retry:
printf("init parameter to pserver:\n");
real param_content1[] = {0.1, 0.2, 0.3};
real param_content2[] = {0.4, 0.5, 0.6};
paddle_parameter** params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * 2);
params[0] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
params[0]->name = names[0];
params[0]->content = (unsigned char*)param_content1;
params[0]->content_len = 3 * sizeof(real);
params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
params[1] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
params[1]->name = names[1];
params[1]->content = (unsigned char*)param_content2;
params[1]->content_len = 3 * sizeof(real);
params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
if (paddle_begin_init_params(c)) {
if (paddle_init_param(c, *params[0], NULL, 0) != 0) {
goto retry;
}
if (paddle_init_param(c, *params[1], NULL, 0) != 0) {
goto retry;
}
if (paddle_finish_init_params(c) != 0) {
goto retry;
}
} else {
fail();
}
printf("get inited parameters from pserver:\n");
// get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, params, 2) != 0) {
fail();
}
print_parameter(params[0]);
print_parameter(params[1]);
printf("send gradient to pserver:\n");
real gradient_content1[] = {0.01, 0.02, 0.03};
real gradinet_content2[] = {0.04, 0.05, 0.06};
paddle_gradient** grads =
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
grads[0]->name = names[0];
grads[0]->content = (unsigned char*)gradient_content1;
grads[0]->content_len = 3 * sizeof(real);
grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
grads[1]->name = names[1];
grads[1]->content = (unsigned char*)gradinet_content2;
grads[1]->content_len = 3 * sizeof(real);
grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
printf("print gradient sent to pserver:\n");
print_parameter(grads[0]);
print_parameter(grads[1]);
if (paddle_send_grads(c, grads, 2) != 0) {
fail();
}
printf("get updated parameters from pserver:\n");
// get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, params, 2) != 0) {
fail();
}
print_parameter(params[0]);
print_parameter(params[1]);
if (paddle_save_model(c, "/tmp/") != 0) {
fail();
}
return 0;
}
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING)
add_subdirectory(test)
endif()
...@@ -30,15 +30,16 @@ import ( ...@@ -30,15 +30,16 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
func add(c *pserver.Client) C.paddle_pserver_client { func add(c *client.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle client := curHandle
...@@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client { ...@@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client {
return client return client
} }
func get(client C.paddle_pserver_client) *pserver.Client { func get(client C.paddle_pserver_client) *client.Client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
return handleMap[client] return handleMap[client]
} }
func remove(client C.paddle_pserver_client) *pserver.Client { func remove(client C.paddle_pserver_client) *client.Client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
h := handleMap[client] h := handleMap[client]
...@@ -80,9 +81,9 @@ func (s selector) Select() bool { ...@@ -80,9 +81,9 @@ func (s selector) Select() bool {
return bool(s) return bool(s)
} }
type lister []pserver.Server type lister []client.Server
func (l lister) List() []pserver.Server { func (l lister) List() []client.Server {
return l return l
} }
...@@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server { ...@@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server {
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client { func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs) a := C.GoString(addrs)
as := strings.Split(a, ",") as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as)) servers := make([]client.Server, len(as))
for i := range as { for i := range as {
servers[i].Index = i servers[i].Index = i
servers[i].Addr = as[i] servers[i].Addr = as[i]
} }
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0)) c := client.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c) return add(c)
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd. // TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
panic("not implemented.") addr := C.GoString(etcd_endpoints)
etcd_client := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
return add(c)
} }
//export paddle_pserver_client_release //export paddle_pserver_client_release
......
cc_binary(main SRCS main.c DEPS paddle_pserver_cclient)
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient) cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
add_style_check_target(test_cclient test_cclient.c)
...@@ -16,7 +16,7 @@ void sendGrads(paddle_pserver_client c) { ...@@ -16,7 +16,7 @@ void sendGrads(paddle_pserver_client c) {
"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000}; "param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
paddle_gradient grad2 = { paddle_gradient grad2 = {
"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000}; "param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
paddle_gradient* grads[2] = {&grad1, &grad2}; paddle_gradient *grads[2] = {&grad1, &grad2};
if (paddle_send_grads(c, grads, 2)) { if (paddle_send_grads(c, grads, 2)) {
fail(); fail();
} }
...@@ -39,7 +39,7 @@ void getParams(paddle_pserver_client c) { ...@@ -39,7 +39,7 @@ void getParams(paddle_pserver_client c) {
param_b.content = content_b; param_b.content = content_b;
param_b.content_len = 3000; param_b.content_len = 3000;
paddle_parameter* params[2] = {&param_a, &param_b}; paddle_parameter *params[2] = {&param_a, &param_b};
if (paddle_get_params(c, params, 2)) { if (paddle_get_params(c, params, 2)) {
fail(); fail();
} }
...@@ -48,6 +48,17 @@ void getParams(paddle_pserver_client c) { ...@@ -48,6 +48,17 @@ void getParams(paddle_pserver_client c) {
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
paddle_pserver_client c = paddle_new_pserver_client(addr, 1); paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
char *config_proto;
size_t config_proto_len = 0;
ssize_t nread;
FILE *fp = fopen("testdata/optimizer.pb", "r");
if (!fp) {
fail();
}
while ((nread = getline(&config_proto, &config_proto_len, fp)) != -1) {
printf("%s", config_proto);
}
fclose(fp);
retry: retry:
if (paddle_begin_init_params(c)) { if (paddle_begin_init_params(c)) {
paddle_parameter param; paddle_parameter param;
...@@ -59,7 +70,8 @@ retry: ...@@ -59,7 +70,8 @@ retry:
param.name = name_a; param.name = name_a;
param.content = content_a; param.content = content_a;
param.content_len = 2000; param.content_len = 2000;
int error = paddle_init_param(c, param, NULL, 0); int error =
paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) { if (error != 0) {
goto retry; goto retry;
} }
...@@ -68,7 +80,7 @@ retry: ...@@ -68,7 +80,7 @@ retry:
param.name = name_b; param.name = name_b;
param.content = content_b; param.content = content_b;
param.content_len = 3000; param.content_len = 3000;
error = paddle_init_param(c, param, NULL, 0); error = paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) { if (error != 0) {
goto retry; goto retry;
} }
......
...@@ -22,6 +22,8 @@ def main(): ...@@ -22,6 +22,8 @@ def main():
# create optimizer # create optimizer
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
......
package pserver package client
import ( import (
"errors" "errors"
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool { ...@@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool {
} }
// InitParam initializes the parameter on parameter servers. // InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil) return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
} }
...@@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error { ...@@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating // SendGrads sends gradients to parameter servers for updating
// parameters. // parameters.
func (c *Client) SendGrads(grads []Gradient) error { func (c *Client) SendGrads(grads []pserver.Gradient) error {
if len(grads) == 0 { if len(grads) == 0 {
return errors.New("no gradient received") return errors.New("no gradient received")
} }
errCh := make(chan error, len(grads)) errCh := make(chan error, len(grads))
for _, g := range grads { for _, g := range grads {
go func(g Gradient) { go func(g pserver.Gradient) {
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil) err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err errCh <- err
}(g) }(g)
...@@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error { ...@@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
type result struct { type result struct {
idx int idx int
param Parameter param pserver.Parameter
err error err error
} }
...@@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) { ...@@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) {
} }
// GetParams gets parameters from parameter servers. // GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) { func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
rCh := make(chan result, len(names)) rCh := make(chan result, len(names))
for idx, name := range names { for idx, name := range names {
go func(name string, idx int) { go func(name string, idx int) {
var parameter Parameter var parameter pserver.Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter) err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err} rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx) }(name, idx)
...@@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) { ...@@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) {
} }
sort.Sort(rs) sort.Sort(rs)
ps := make([]Parameter, len(rs)) ps := make([]pserver.Parameter, len(rs))
for i := range rs { for i := range rs {
ps[i] = rs[i].param ps[i] = rs[i].param
} }
......
package pserver_test package client_test
import ( import (
"context"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
) )
const numPserver = 10 const (
numPserver = 10
etcdEndpoints = "127.0.0.1:2379"
timeout = 2 * time.Second
)
var port [numPserver]int var pserverClientPorts [numPserver]int
func init() { // this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
var ports [numPserver]int
for i := 0; i < numPserver; i++ { for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
...@@ -27,7 +39,7 @@ func init() { ...@@ -27,7 +39,7 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
port[i] = p ports[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService(0) s, err := pserver.NewService(0)
...@@ -48,6 +60,31 @@ func init() { ...@@ -48,6 +60,31 @@ func init() {
} }
}(l) }(l)
} }
return ports
}
func initNativeClient() {
pserverClientPorts = initClient()
}
func initEtcdClient() {
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{etcdEndpoints},
DialTimeout: time.Second * time.Duration(1),
})
if err != nil {
log.Errorf("err %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath)
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
ports := initClient()
for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
}
cancel()
client.Close()
} }
type selector bool type selector bool
...@@ -56,36 +93,35 @@ func (s selector) Select() bool { ...@@ -56,36 +93,35 @@ func (s selector) Select() bool {
return bool(s) return bool(s)
} }
type lister []pserver.Server type lister []client.Server
func (l lister) List() []pserver.Server { func (l lister) List() []client.Server {
return l return l
} }
func TestClientFull(t *testing.T) { func ClientTest(t *testing.T, c *client.Client) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
selected := c.BeginInitParams() selected := c.BeginInitParams()
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
const numParameter = 100 const numParameter = 100
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
var p pserver.Parameter var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i) p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32 p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100) p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p}) err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
err := c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -124,3 +160,21 @@ func TestClientFull(t *testing.T) { ...@@ -124,3 +160,21 @@ func TestClientFull(t *testing.T) {
} }
} }
} }
func TestNativeClient(t *testing.T) {
initNativeClient()
servers := make([]client.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1)
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) {
initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true))
ClientTest(t, c2)
}
package client
import (
"context"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
const (
DefaultEtcdTimeout time.Duration = 5 * time.Second
)
// EtcdClient is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type EtcdClient struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
}
// Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err)
time.Sleep(p.timeout)
continue
}
log.Debugf("Get psDesired number: %d", psDesired)
break
}
return psDesired
}
// List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server {
psDesired := p.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout)
continue
}
psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout)
continue
}
log.Infof("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i
servers[i].Addr = psAddr
}
break
}
return servers
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
for {
cli, err = clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: DefaultEtcdTimeout,
})
if err != nil {
log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout)
continue
}
break
}
log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{
client: cli,
timeout: DefaultEtcdTimeout,
endpoints: ep,
}
return client
}
...@@ -13,6 +13,13 @@ import ( ...@@ -13,6 +13,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
// PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
)
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination. // tolerance, service registry and coordination.
type EtcdClient struct { type EtcdClient struct {
...@@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) {
// it at the same time. // it at the same time.
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers) _, err := e.initDesiredPservers(ctx, e.numPservers)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
...@@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) { ...@@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) {
return pserverIdx, nil return pserverIdx, nil
} }
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired) dsStr := c.Get(PsDesired)
if dsStr == "" { if dsStr == "" {
...@@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
ps := c.Get(psKey) ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debugf("got value (%s) for key: %s", ps, psKey)
......
#include <stdlib.h>
#include "optimizer.h"
typedef int (*update_func)(void*, void*, paddle_element_type, const void*, int);
typedef void (*release_func)(void*);
typedef struct paddle_optimizer {
update_func update;
release_func release;
void* optimizer;
} paddle_optimizer;
void paddle_release_optimizer(paddle_optimizer* o) {
o->release(o->optimizer);
free(o);
}
int paddle_update_parameter(paddle_optimizer* o,
void* buffer,
paddle_element_type element_type,
const void* gradient,
int num_bytes) {
return o->update(o->optimizer, buffer, element_type, gradient, num_bytes);
}
typedef struct { double learning_rate; } SGD_optimizer;
int update_SGD(void* optimizer,
void* buffer,
paddle_element_type element_type,
const void* gradient,
int num_bytes) {
SGD_optimizer* o = (SGD_optimizer*)optimizer;
float* parameter = (float*)buffer;
float* grad = (float*)gradient;
int i;
for (i = 0; i < num_bytes / sizeof(float); ++i) {
parameter[i] -= o->learning_rate * grad[i];
}
return 0;
}
void release_SGD(void* optimizer) {
SGD_optimizer* o = (SGD_optimizer*)optimizer;
// nothing allocated on heap
}
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
SGD_optimizer* impl = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
impl->learning_rate = learning_rate;
paddle_optimizer* opt = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
opt->update = update_SGD;
opt->release = release_SGD;
opt->optimizer = impl;
return opt;
}
package pserver package pserver
/* // #cgo CFLAGS: -I ../../
#include "optimizer.h" // //FIXME: ldflags contain "build" path
*/ // #cgo LDFLAGS: ../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++
// #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h>
// #include <string.h>
import "C" import "C"
import ( import (
"fmt" "fmt"
"unsafe" "unsafe"
)
type optimizerType int
const ( log "github.com/sirupsen/logrus"
sgd optimizerType = iota
) )
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct { type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType
} }
func newOptimizer(t optimizerType, learning_rate float64) *optimizer { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
o := &optimizer{} o := &optimizer{}
o.opt = C.paddle_create_SGD_optimizer(C.double(learning_rate)) o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ConfigSize": len(c),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
(*C.char)(nullPtr), 0)
return o return o
} }
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { func (o *optimizer) GetWeights() []byte {
if len(p.Content) != len(g.Content) { var buffer unsafe.Pointer
return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, len(p.Content), len(g.Content)) buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer)
} return cArrayToSlice(buffer, int(buffer_len)*C.sizeof_float)
}
if p.ElementType != g.ElementType { func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType) if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float)
if r != 0 { if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r) return fmt.Errorf("optimizer update returned error code: %d", r)
} }
......
#ifndef PADDLE_PSERVER_OPTIMIZER_H
#define PADDLE_PSERVER_OPTIMIZER_H
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
PADDLE_ELEMENT_TYPE_INT64 = 2,
PADDLE_ELEMENT_TYPE_UINT64 = 3,
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
struct paddle_optimizer;
struct paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
void paddle_release_optimizer(struct paddle_optimizer* o);
int paddle_update_parameter(struct paddle_optimizer* o,
void* buffer,
paddle_element_type element_type,
const void* gradient,
int num_bytes);
#endif /* PADDLE_PSERVER_OPTIMIZER_H */
package pserver package pserver
import "testing" import (
"io/ioutil"
"testing"
)
func TestSGDCreateRelease(t *testing.T) { func TestOptimizerCreateRelease(t *testing.T) {
o := newOptimizer(sgd, 1) p := Parameter{
Name: "a",
ElementType: Int32,
}
p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param)
o.Cleanup() o.Cleanup()
} }
...@@ -24,9 +24,6 @@ const ( ...@@ -24,9 +24,6 @@ const (
Float64 Float64
) )
// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server. // Parameter is a piece of data to sync with the parameter server.
type Parameter struct { type Parameter struct {
Name string Name string
...@@ -48,9 +45,8 @@ type Service struct { ...@@ -48,9 +45,8 @@ type Service struct {
initialized chan struct{} initialized chan struct{}
idx int idx int
mu sync.Mutex mu sync.Mutex
opt *optimizer optMap map[string]*optimizer
paramMap map[string]Parameter
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
...@@ -58,9 +54,8 @@ type Service struct { ...@@ -58,9 +54,8 @@ type Service struct {
func NewService(idx int) (*Service, error) { func NewService(idx int) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
opt: newOptimizer(sgd, 0.005),
} }
s.paramMap = make(map[string]Parameter) s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s, nil return s, nil
} }
...@@ -81,7 +76,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -81,7 +76,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is // TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
return nil return nil
} }
...@@ -110,12 +105,12 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error { ...@@ -110,12 +105,12 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
p, ok := s.paramMap[g.Name] o, ok := s.optMap[g.Name]
if !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", g.Name) return fmt.Errorf("parameter: %s does not exist", g.Name)
} }
return s.opt.UpdateParameter(p, g) return o.UpdateParameter(g)
} }
// GetParam gets parameters from the parameter server. // GetParam gets parameters from the parameter server.
...@@ -124,7 +119,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -124,7 +119,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
p, ok := s.paramMap[name] opt, ok := s.optMap[name]
if !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", name) return fmt.Errorf("parameter: %s does not exist", name)
} }
...@@ -136,7 +131,9 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -136,7 +131,9 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// nature. This race condition is allowed deliberately // nature. This race condition is allowed deliberately
// to save the program from making a copy of the // to save the program from making a copy of the
// paramter content. // paramter content.
*parameter = p parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
return nil return nil
} }
......
package pserver_test package pserver_test
import ( import (
"io/ioutil"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
...@@ -9,7 +10,11 @@ import ( ...@@ -9,7 +10,11 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
func TestFull(t *testing.T) { const (
OptimizerConfig = "./client/c/test/testdata/optimizer.pb"
)
func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -18,7 +23,12 @@ func TestFull(t *testing.T) { ...@@ -18,7 +23,12 @@ func TestFull(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -27,7 +37,7 @@ func TestFull(t *testing.T) { ...@@ -27,7 +37,7 @@ func TestFull(t *testing.T) {
p1.Name = "param_b" p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -48,6 +58,7 @@ func TestFull(t *testing.T) { ...@@ -48,6 +58,7 @@ func TestFull(t *testing.T) {
} }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil) err = s.SendGrad(g1, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
...@@ -142,7 +153,12 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -142,7 +153,12 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
......
...@@ -5,3 +5,6 @@ nv_test(dim_test SRCS dim_test.cu DEPS ddim) ...@@ -5,3 +5,6 @@ nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc) cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc) cc_test(enforce_test SRCS enforce_test.cc)
proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax="proto2";
package paddle.framework;
// Attribute Type for paddle's Op.
// Op contains many attributes. Each type of attributes could be different.
// The AttrType will be shared between AttrDesc and AttrProto.
enum AttrType {
INT = 0;
FLOAT = 1;
STRING = 2;
INTS = 3;
FLOATS = 4;
STRINGS = 5;
}
\ No newline at end of file
# Network Design
`Network` is the container and controller of a set of operators,
user can build a real network from a `NetDesc` which is a protobuf message
and use `Network.Run()` to run all the operators in the network.
A network object knows all Operators belonging to this network. Variables,
which are inputs and outputs of these operators,
are created and managed by a hierarchy of Scope objects.
# API
## Net
To make the `Network` extendable, a base class is defined like this
```c++
// operator's index stored in a network.
typedef int OpIndex;
// The minimum a network should be implemented.
class Net {
public:
// run all the operators and return success(true) or not, with all the
// variables are located in `scope`. `context` describes the detail execution
// environment for ops. `begin` and `end` specify the scope of `ops_` to run,
// If no positive indexes are provided, all operators in `ops_` will run.
virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1,
OpIndex end = -1) const = 0;
// Add an Operator according to `def`.
virtual OpIndex AddOp(const proto::OpDef &def) = 0;
// Add optimizer operators acctording to `attrs`.
virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0;
// Add backward operators.
virtual Error AddBackwardOps() = 0;
// Infer the shapes of variables required by operators in the network. The
// `scope` will be mutated according to the inferred shapes.
static std::unique_ptr<Net> Create(const NetDesc &def = NetDesc());
};
```
All network implementations should build networks from a protobuf message which
describes the structure of a real network; `Run` method should be implemented by
all implementations to offer a universal method to forward or backward compute a network.
`Net::Create` is a method of factory pattern and can be implemented like
```c++
std::unique<Net> Net::Create(const NetDesc& def) {
switch (def.model_type()) {
case NN:
return new Network(def);
case Recursive:
return new RecursiveNet(def);
case Recurrent:
return new RecurrentNet(def);
}
return nullptr;
}
```
Network is designed as the container of operators. to make it more extendable,
we decouple it from the related variable resources.
`Run(Scope* scope)` takes the scope as a argument so that it can run in different scopes.
Finally, `Net` can be used as followed
```c++
Scope default_scope;
OpContext default_context;
auto net = Net::CreateNet(def);
if (net) {
net.Run(&default_scope, &default_context);
}
```
## `PlainNet` as a simple implementation of `BaseNet`
A very basic implementation is as follows. All it does is simply to run every operators in sequence.
```c++
class PlainNet : public Net {
public:
// Create a network describe by `def`. NetDesc is the definition of a network.
PlainNet(const NetDesc &def);
// Infer all the operators' input and output varialbes' shapes, will be called before every mini-batch
training.
virtual Error InferShape(Scope *scope) override;
// Run all the operators with the `scope`, if no scope is provided, default
// scope will be used instead. If no OpContext is provicded, default context will be used.
virtual Error Run(Scope *scope = nullptr, OpContext *context=nullptr, OpIndex begin = -1,
OpIndex end = -1) const override;
virtual OpIndex AddOp(const proto::OpDef &def) override;
virtual Error AddOptimizerOps(const OptAttrs &attrs) override;
virtual Error AddBackwardOps() override;
protected:
// Create operators accordding to `def`, will be called by the constructor.
Error BuildNet(const NetDesc &def);
// Add a operator which is identified as `type` and has attributes described
// in `attrs`, the `inputs` are the keys of readonly input variables,
// `outputs` are keys of mutable output variables. An `OpIndex` will be
// returned to indicate the offset of the new operator in `ops_`.
OpIndex AddOp(const std::string &type, const std::vector<string> &inputs,
const std::vector<string> &outputs,
const OprAttr &attrs = OprAttr());
private:
// the operators owned by `Network`.
std::vector<Operator> ops_;
};
```
`PlainNet` will create operators so that a private member `ops_` is defined,
the operators are created by `CreateNet`, and each operator is created by `AddOp`.
## PlainNet Usage
`PlainNet` can be used to define and run a network as follows
```c++
// create an empty scope located on CPU device.
Scope scope(CPUPlace());
// create and init variables described in `net_desc`.
scope.CreateVariables(net_desc);
scope.InitVariables(net_desc);
// create a network according to `net_desc`
auto net = Net::CreateNet(net_desc);
// Add more operators if needed.
net->AddOp(add...);
net->AddOp(fc...);
net->AddBackwardOps();
net->AddOptimizerOps();
// run the network providing the `scope`.
net.Run(&scope);
```
## `NetBuilder` as a C++ syntax wrapper
This is a detailed description of the user-related C++ network API, and may not needed in the prototype development stage.
The `NetBuilder` will give users a much simpler syntax as follows to create a network, and demonstrates how to use the `BaseNet`'s raw interfaces.
```c++
Variable* fc_out = builder.AddOp("fc", input=image, size=100, activation="Sigmoid");
Variable* prediction = builder.AddOp("fc", input=fc_out, size=10, activation="Sigmoid");
Variable* loss = builder.AddOp("cross_entropy", input=prediction, label=label);
Variable* avg_loss = builder.AddOp("mean", loss);
builder.BackwardFrom(avg_loss)
builder.AddOptimization(1e-4, "adam");
builder.Run();
```
`NetBuilder` will call `Net` 's virtual functions to change the real network structure, here is a sample definition
```c++
class NetBuilder final {
public:
NetBuilder(Net* net) : net_(net) {}
Variable* AddOp(const string& type, const vector<Variable>& inputs,
size_t size, Activation act) {
// much code here.
// ...
net_->AddOp(def);
need_rebuild_net_ = true;
net_->InferShape();
// ...
}
Error BackwardFrom(const Variable& cost);
Error Run(Scope* scope, OpContext* context, bool need_backward = true) {
// backward.
if (need_backward) {
if (need_rebuild_net_) {
AddBackwardOps();
AddOptimizerOps();
}
net_->Run(scope, context);
return;
}
// just forward.
net_->Run(scope, context, 0, last_forward_op_);
}
protected:
Error AddBackwardOps();
Error AddOptimizerOps();
private:
Net* net_;
OpIndex last_forward_op_{-1};
bool need_rebuild_net_{true};
}
```
## Compatibility with RNN
Benefitting from the decoupling of `PlainNet.Run` and `Scope`, `PlainNet` is compatible with future RNN design,
for example we can implement a simple recurrent neural network as follows
```c++
// copy some `vars` form `source` to `target`
void Copy(const Scope &source, Scope &target,
const std::vector<std::string> &vars);
Scope default_scope;
// some initial mutations on `default_scope` here.
auto rnn_step_net = PlainNet(rnn_step_net_def);
// Create rnn's states, the last scope is used to store rnn outputs.
Scope *rnn_states = new Scope[num_states + 1];
for (int i = 0; i < num_states + 1; i++) {
// Initialize all rnn state scopes, copy parameters and so on.
rnn_states[i].CreateVars(rnn_step_net_def);
Copy(default_scope, rnn_states[i], rnn_related_vars);
// Prepare rnn's inlinks, just copy inlink variables to each state.
Copy(default_scope, rnn_states[i], inlink_vars);
}
// Run the rnn.
for (int i = 0; i < num_states; i++) {
rnn_step_net.Run(rnn_states[i]);
// Copy current state's state variables to next state, the related variables
// are named like "previous_state_xxx".
Copy(rnn_states[i], rnn_states[i + 1], pre_state_vars)
}
// Copy rnn's final outputs to `default_scope`.
Copy(rnn_states[num_states], default_scope, outlink_vars);
```
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Protocol Message for 3rd-party language binding.
//
// Paddle Python package will use `OpProto` to generate op creation methods.
// The op creation methods take user's input and generate `OpDesc` proto message,
// then pass `OpDesc` to C++ side and create Op pointer.
//
syntax="proto2";
package paddle.framework;
import "attr_type.proto";
// Attribute protocol message for 3rd-party language binding.
// It will store the Op support what attribute and what type.
message AttrProto {
// Supported attribute name. e.g. `scale` for cosine op.
required string name = 1;
// Supported attribute type.
required AttrType type = 2;
// Supported attribute comments. It helps 3rd-party language generate doc-string.
required string comment = 3;
}
// Input or output message for 3rd-party language binding.
// It contains parameter name and its comments.
message VarProto {
// Input or output name in that op creation function.
// e.g. `cos(a, b, output, ...)`, "a", "b", "output" are names.
required string name = 1;
// The comment for that input. It helps 3rd-party language generate doc-string.
required string comment = 2;
}
// Op protocol message for 3rd-party language binding.
// It contains all information for generating op creation method.
message OpProto {
// The input information to generate op creation method.
repeated VarProto inputs = 1;
// The output information to generate op creation method.
repeated VarProto outputs = 2;
// The attribute information to generate op creation method.
repeated AttrProto attrs = 3;
// The comments for that Op. It helps 3rd-party language generate
// doc-string. The whole documentation of that Op is generated by comment,
// inputs, outputs, attrs together.
required string comment = 4;
// The type of that Op.
required string type = 5;
}
#include <gtest/gtest.h>
#include <paddle/framework/op_proto.pb.h>
TEST(TestOpProto, ALL) {
paddle::framework::OpProto proto;
{
auto ipt = proto.mutable_inputs()->Add();
*ipt->mutable_name() = "a";
*ipt->mutable_comment() = "the one input of cosine op";
}
{
auto ipt = proto.mutable_inputs()->Add();
*ipt->mutable_name() = "b";
*ipt->mutable_comment() = "the other input of cosine op";
}
{
auto opt = proto.mutable_outputs()->Add();
*opt->mutable_name() = "output";
*opt->mutable_comment() = "the output of cosine op";
}
{
auto attr = proto.mutable_attrs()->Add();
*attr->mutable_name() = "scale";
attr->set_type(paddle::framework::AttrType::FLOAT);
*attr->mutable_comment() = "the scale attribute of cosine op";
}
proto.set_type("cos");
*proto.mutable_comment() = "cosine op, output = scale * cos(a, b)";
ASSERT_TRUE(proto.IsInitialized());
}
\ No newline at end of file
...@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS ...@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies}) add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_simple_unittest(serialization_test) add_simple_unittest(serialization_test)
add_simple_unittest(parameter_optimizer_test) add_simple_unittest(parameter_optimizer_test)
......
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags) cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog)
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info gflags glog) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) nv_library(gpu_info SRCS gpu_info.cc DEPS gflags)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_library(dynamic_loader SRCS dynload/dynamic_loader.cc)
...@@ -28,5 +28,15 @@ size_t CpuMinChunkSize(); ...@@ -28,5 +28,15 @@ size_t CpuMinChunkSize();
//! Get the maximum chunk size for buddy allocator. //! Get the maximum chunk size for buddy allocator.
size_t CpuMaxChunkSize(); size_t CpuMaxChunkSize();
int GetCurrentDeviceId(void) {
int device_id;
throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed");
return device_id;
}
void SetDeviceId(int device_id) {
throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed");
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cublas_v2.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load cublas routine
* via operator overloading.
*
* note: default dynamic linked libs
*/
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, \
paddle::platform::dynload::GetCublasDsoHandle, \
&cublas_dso_handle); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
} \
} __name; // struct DynLoad__##__name
#else
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; // struct DynLoad__##__name
#endif
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
// include all needed cublas functions in HPPL
// clang-format off
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv) \
__macro(cublasDgemv) \
__macro(cublasSgemm) \
__macro(cublasDgemm) \
__macro(cublasSgeam) \
__macro(cublasDgeam) \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_WRAP
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH
// clang-format on
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
#else
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cudnn.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using cudnn_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(cudnn_dso_flag, \
paddle::platform::dynload::GetCudnnDsoHandle, \
&cudnn_dso_handle); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#endif
/**
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor) \
__macro(cudnnSetTensor4dDescriptorEx) \
__macro(cudnnGetConvolutionNdForwardOutputDim) \
__macro(cudnnGetConvolutionForwardAlgorithm) \
__macro(cudnnCreateTensorDescriptor) \
__macro(cudnnDestroyTensorDescriptor) \
__macro(cudnnCreateFilterDescriptor) \
__macro(cudnnSetFilter4dDescriptor) \
__macro(cudnnSetPooling2dDescriptor) \
__macro(cudnnDestroyFilterDescriptor) \
__macro(cudnnCreateConvolutionDescriptor) \
__macro(cudnnCreatePoolingDescriptor) \
__macro(cudnnDestroyPoolingDescriptor) \
__macro(cudnnSetConvolution2dDescriptor) \
__macro(cudnnDestroyConvolutionDescriptor) \
__macro(cudnnCreate) \
__macro(cudnnDestroy) \
__macro(cudnnSetStream) \
__macro(cudnnActivationForward) \
__macro(cudnnConvolutionForward) \
__macro(cudnnConvolutionBackwardBias) \
__macro(cudnnGetConvolutionForwardWorkspaceSize) \
__macro(cudnnTransformTensor) \
__macro(cudnnPoolingForward) \
__macro(cudnnPoolingBackward) \
__macro(cudnnSoftmaxBackward) \
__macro(cudnnSoftmaxForward) \
__macro(cudnnGetVersion) \
__macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
__macro(cudnnAddTensor) \
__macro(cudnnConvolutionBackwardData) \
__macro(cudnnConvolutionBackwardFilter)
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
// APIs available after R3:
#if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
__macro(cudnnGetConvolutionBackwardDataAlgorithm) \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#endif
// APIs available after R4:
#if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
__macro(cudnnBatchNormalizationForwardTraining) \
__macro(cudnnBatchNormalizationForwardInference) \
__macro(cudnnBatchNormalizationBackward)
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#endif
// APIs in R5
#if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
__macro(cudnnCreateActivationDescriptor) \
__macro(cudnnSetActivationDescriptor) \
__macro(cudnnGetActivationDescriptor) \
__macro(cudnnDestroyActivationDescriptor)
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#endif
#undef CUDNN_DNN_ROUTINE_EACH
// clang-format on
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <curand.h>
#include "paddle/platform/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
typedef curandStatus_t (*curandFunc)(Args...); \
std::call_once(curand_dso_flag, \
paddle::platform::dynload::GetCurandDsoHandle, \
&curand_dso_handle); \
void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#endif
/* include all needed curand functions in HPPL */
// clang-format off
#define CURAND_RAND_ROUTINE_EACH(__macro) \
__macro(curandCreateGenerator) \
__macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)
#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "dynamic_loader.h"
#include <dlfcn.h>
#include <memory>
#include <mutex>
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, "
"/usr/local/cudnn/lib. If empty [default], dlopen "
"will search cudnn from LD_LIBRARY_PATH");
DEFINE_string(cuda_dir, "",
"Specify path for loading cuda library, such as libcublas, "
"libcurand. For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH");
DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
namespace paddle {
namespace platform {
namespace dynload {
static inline std::string join(const std::string& part1,
const std::string& part2) {
// directory separator
const char sep = '/';
if (!part2.empty() && part2.front() == sep) {
return part2;
}
std::string ret;
ret.reserve(part1.size() + part2.size() + 1);
ret = part1;
if (!ret.empty() && ret.back() != sep) {
ret += sep;
}
ret += part2;
return ret;
}
static inline void GetDsoHandleFromDefaultPath(std::string& dso_path,
void** dso_handle,
int dynload_flags) {
VLOG(3) << "Try to find library: " << dso_path
<< " from default system path.";
// default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH
*dso_handle = dlopen(dso_path.c_str(), dynload_flags);
// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to
// bring System Integrity Projection (SIP), if dso_handle
// is null, search from default package path in Mac OS.
#if defined(__APPLE__) || defined(__OSX__)
if (nullptr == *dso_handle) {
dso_path = join("/usr/local/cuda/lib/", dso_path);
*dso_handle = dlopen(dso_path.c_str(), dynload_flags);
if (nullptr == *dso_handle) {
if (dso_path == "libcudnn.dylib") {
LOG(FATAL)
<< "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n" // NOLINT
<< "For instance, sudo tar -xzf "
"cudnn-7.5-osx-x64-v5.0-ga.tgz -C " // NOLINT
<< "/usr/local \n sudo chmod a+r "
"/usr/local/cuda/include/cudnn.h " // NOLINT
<< "/usr/local/cuda/lib/libcudnn*";
}
}
}
#endif
}
static inline void GetDsoHandleFromSearchPath(const std::string& search_root,
const std::string& dso_name,
void** dso_handle) {
int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
*dso_handle = nullptr;
std::string dlPath = dso_name;
if (search_root.empty()) {
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
} else {
// search xxx.so from custom path
dlPath = join(search_root, dso_name);
*dso_handle = dlopen(dlPath.c_str(), dynload_flags);
// if not found, search from default path
if (nullptr == *dso_handle) {
LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " ("
<< dlerror() << ")";
dlPath = dso_name;
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
}
}
CHECK(nullptr != *dso_handle) << "Failed to find dynamic library: " << dlPath
<< " (" << dlerror() << ") \n"
<< "Please specify its path correctly using "
"following ways: \n"
<< "Method. set environment variable "
"LD_LIBRARY_PATH on Linux or "
<< "DYLD_LIBRARY_PATH on Mac OS. \n"
<< "For instance, issue command: export "
"LD_LIBRARY_PATH=... \n"
<< "Note: After Mac OS 10.11, using the "
"DYLD_LIBRARY_PATH is impossible "
<< "unless System Integrity Protection (SIP) "
"is disabled.";
}
void GetCublasDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle);
#endif
}
void GetCudnnDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle);
#endif
}
void GetCurandDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle);
#endif
}
void GetWarpCTCDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle);
#endif
}
void GetLapackDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle);
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace platform {
namespace dynload {
/**
* @brief load the DSO of CUBLAS
*
* @param **dso_handle dso handler
*
*/
void GetCublasDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CUDNN
*
* @param **dso_handle dso handler
*
*/
void GetCudnnDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CURAND
*
* @param **dso_handle dso handler
*
*/
void GetCurandDsoHandle(void** dso_handle);
/**
* @brief load the DSO of warp-ctc
*
* @param **dso_handle dso handler
*
*/
void GetWarpCTCDsoHandle(void** dso_handle);
/**
* @brief load the DSO of lapack
*
* @param **dso_handle dso handler
*
*/
void GetLapackDsoHandle(void** dso_handle);
} // namespace dynload
} // namespace platform
} // namespace paddle
...@@ -5,6 +5,8 @@ import paddle.trainer_config_helpers.optimizers as v1_optimizers ...@@ -5,6 +5,8 @@ import paddle.trainer_config_helpers.optimizers as v1_optimizers
""" """
Optimizers(update equation) for SGD method. Optimizers(update equation) for SGD method.
TODO(zhihong) : create new optimizer with proto config, add new optimizer here
TODO(yuyang18): Complete comments. TODO(yuyang18): Complete comments.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册