提交 365b457a 编写于 作者: L liaogang

Merge conflicts

...@@ -19,3 +19,9 @@ third_party/ ...@@ -19,3 +19,9 @@ third_party/
# clion workspace. # clion workspace.
cmake-build-* cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake
...@@ -4,6 +4,7 @@ cache: ...@@ -4,6 +4,7 @@ cache:
- $HOME/.ccache - $HOME/.ccache
- $HOME/.cache/pip - $HOME/.cache/pip
- $TRAVIS_BUILD_DIR/build/third_party - $TRAVIS_BUILD_DIR/build/third_party
- $TRAVIS_BUILD_DIR/build_android/third_party
sudo: required sudo: required
dist: trusty dist: trusty
os: os:
...@@ -11,6 +12,7 @@ os: ...@@ -11,6 +12,7 @@ os:
env: env:
- JOB=build_doc - JOB=build_doc
- JOB=check_style - JOB=check_style
- JOB=build_android
addons: addons:
apt: apt:
packages: packages:
......
...@@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(Boost QUIET) if(NOT ANDROID)
find_package(Boost QUIET)
endif()
include(simd) include(simd)
...@@ -97,6 +99,7 @@ include(external/swig) # download, build, install swig ...@@ -97,6 +99,7 @@ include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any include(external/any) # download libn::any
include(external/eigen) # download eigen3 include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(cudnn) # set cudnn libraries, must before configure include(cudnn) # set cudnn libraries, must before configure
include(configure) # add paddle env configuration include(configure) # add paddle env configuration
...@@ -139,6 +142,10 @@ endif(USE_NNPACK) ...@@ -139,6 +142,10 @@ endif(USE_NNPACK)
add_subdirectory(proto) add_subdirectory(proto)
# "add_subdirectory(go)" should be placed after the following loine,
# because it depends on paddle/optimizer.
add_subdirectory(paddle/optimizer)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it. # placed after this block, because they depends on it.
if(WITH_GOLANG) if(WITH_GOLANG)
...@@ -146,7 +153,9 @@ if(WITH_GOLANG) ...@@ -146,7 +153,9 @@ if(WITH_GOLANG)
endif(WITH_GOLANG) endif(WITH_GOLANG)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(python) if(WITH_PYTHON)
add_subdirectory(python)
endif()
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
...@@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
SET(CMAKE_SYSTEM_PROCESSOR armv7-a) SET(CMAKE_SYSTEM_PROCESSOR armv7-a)
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-") SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF() ENDIF()
...@@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}") STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}") STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}")
......
...@@ -52,6 +52,7 @@ ExternalProject_Add( ...@@ -52,6 +52,7 @@ ExternalProject_Add(
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog) ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog) LIST(APPEND external_project_dependencies glog)
...@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND}) ...@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND})
# arm_soft_fp_abi branch of OpenBLAS to support softfp # arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5") SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0) IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(TARGET "ARMV7")
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(TARGET "ARMV8")
ENDIF()
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI) ELSEIF(RPI)
# use hardfp # use hardfp
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")
......
INCLUDE(ExternalProject)
SET(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind)
INCLUDE_DIRECTORIES(${PYBIND_SOURCE_DIR}/src/extern_pybind/include)
ExternalProject_Add(
extern_pybind
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/pybind/pybind11.git"
GIT_TAG "v2.1.1"
PREFIX ${PYBIND_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/pybind_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(pybind STATIC ${dummyfile})
else()
add_library(pybind INTERFACE)
endif()
add_dependencies(pybind extern_pybind)
LIST(APPEND external_project_dependencies pybind)
...@@ -18,6 +18,9 @@ INCLUDE(python_module) ...@@ -18,6 +18,9 @@ INCLUDE(python_module)
FIND_PACKAGE(PythonInterp 2.7) FIND_PACKAGE(PythonInterp 2.7)
IF(WITH_PYTHON) IF(WITH_PYTHON)
FIND_PACKAGE(PythonLibs 2.7) FIND_PACKAGE(PythonLibs 2.7)
# Fixme: Maybe find a static library. Get SHARED/STATIC by FIND_PACKAGE.
ADD_LIBRARY(python SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET python PROPERTY IMPORTED_LOCATION ${PYTHON_LIBRARIES})
ENDIF(WITH_PYTHON) ENDIF(WITH_PYTHON)
SET(py_env "") SET(py_env "")
......
...@@ -109,7 +109,9 @@ set(COMMON_FLAGS ...@@ -109,7 +109,9 @@ set(COMMON_FLAGS
-Wno-unused-function -Wno-unused-function
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=sign-compare -Wno-error=sign-compare
-Wno-error=unused-local-typedefs) -Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in Pybind11
)
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
-fPIC -fPIC
......
...@@ -90,11 +90,11 @@ ...@@ -90,11 +90,11 @@
# including binary directory for generated headers. # including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE AND NOT ANDROID)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl") set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE) endif(NOT APPLE AND NOT ANDROID)
function(merge_static_libs TARGET_NAME) function(merge_static_libs TARGET_NAME)
set(libs ${ARGN}) set(libs ${ARGN})
...@@ -302,7 +302,7 @@ function(go_library TARGET_NAME) ...@@ -302,7 +302,7 @@ function(go_library TARGET_NAME)
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code # Golang build source code
...@@ -321,14 +321,11 @@ function(go_binary TARGET_NAME) ...@@ -321,14 +321,11 @@ function(go_binary TARGET_NAME)
cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
# TODO: don't know what ${TARGET_NAME}_link does
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary) endfunction(go_binary)
...@@ -336,15 +333,18 @@ endfunction(go_binary) ...@@ -336,15 +333,18 @@ endfunction(go_binary)
function(go_test TARGET_NAME) function(go_test TARGET_NAME)
set(options OPTIONAL) set(options OPTIONAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs DEPS)
cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS} ".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
function(proto_library TARGET_NAME) function(proto_library TARGET_NAME)
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x} \frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x}
假设 :math:`z = f(W^T x + b)` ,那么 假设 :math:`z = W^T x + b` ,那么
.. math:: .. math::
......
...@@ -37,7 +37,7 @@ Suppose our loss function is :math:`c(y)`, then ...@@ -37,7 +37,7 @@ Suppose our loss function is :math:`c(y)`, then
\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x} \frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x}
Suppose :math:`z = f(W^T x + b)`, then Suppose :math:`z = W^T x + b`, then
.. math:: .. math::
......
...@@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c) ...@@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver) add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master) add_subdirectory(cmd/master)
add_subdirectory(master/c) add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
go_binary(master SRC master.go DEPS paddle_go_optimizer) go_binary(master SRC master.go)
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
...@@ -18,53 +19,47 @@ func main() { ...@@ -18,53 +19,47 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) level, err := log.ParseLevel(*logLevel)
if err != nil { candy.Must(err)
panic(err)
}
log.SetLevel(level) log.SetLevel(level)
var idx int var idx int
var cp pserver.Checkpoint var cp pserver.Checkpoint
var e *pserver.EtcdClient var e *pserver.EtcdClient
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
timeout := time.Second * time.Duration((*etcdTimeout)) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register() idx, err = e.Register()
candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
if err != nil { if err != nil {
panic(err) log.Errorf("Fetch checkpoint failed, %s", err)
} }
} }
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil { candy.Must(err)
panic(err)
}
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { candy.Must(err)
panic(err)
}
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { candy.Must(err)
panic(err)
}
log.Infof("start pserver at port %d", *port) log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err)
if err != nil {
panic(err)
}
} }
hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7 hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-06-27T14:05:48.925262819+08:00 updated: 2017-07-11T10:04:40.786745417+08:00
imports: imports:
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7 version: cb2a496c4ddd1c87a9f280e116649b599999ec79
subpackages: subpackages:
- auth/authpb - auth/authpb
- clientv3 - clientv3
...@@ -22,7 +22,9 @@ imports: ...@@ -22,7 +22,9 @@ imports:
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
...@@ -34,11 +36,11 @@ imports: ...@@ -34,11 +36,11 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 version: abf9c25f54453410d0c6668e519582a9e1115027
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77 version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform
......
...@@ -10,3 +10,4 @@ import: ...@@ -10,3 +10,4 @@ import:
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy
if(WITH_TESTING)
go_test(master_test)
endif()
...@@ -68,7 +68,7 @@ func (c *Client) getRecords() { ...@@ -68,7 +68,7 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data // We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly // instance of the task is read. This is not exactly
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
c.taskFinished(t.ID) c.taskFinished(t.Meta.ID)
} }
} }
...@@ -118,6 +118,11 @@ func (c *Client) taskFinished(taskID int) error { ...@@ -118,6 +118,11 @@ func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil) return c.conn.Call("Service.TaskFinished", taskID, nil)
} }
// TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFailed", meta, nil)
}
// NextRecord returns next record in the dataset. // NextRecord returns next record in the dataset.
// //
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
......
...@@ -95,10 +95,16 @@ func TestGetFinishTask(t *testing.T) { ...@@ -95,10 +95,16 @@ func TestGetFinishTask(t *testing.T) {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].ID) err = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", err, i)
} }
err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
tasks = tasks[1:] tasks = tasks[1:]
task, err := c.getTask() task, err := c.getTask()
if err != nil { if err != nil {
...@@ -107,7 +113,7 @@ func TestGetFinishTask(t *testing.T) { ...@@ -107,7 +113,7 @@ func TestGetFinishTask(t *testing.T) {
tasks = append(tasks, task) tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.ID) err = c.taskFinished(task.Meta.ID)
if err != nil { if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", err, i)
} }
......
...@@ -31,30 +31,36 @@ type Chunk struct { ...@@ -31,30 +31,36 @@ type Chunk struct {
Index recordio.Index // chunk index Index recordio.Index // chunk index
} }
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}
// Task is the basic unit of data instances assigned to trainers. // Task is the basic unit of data instances assigned to trainers.
type Task struct { type Task struct {
ID int Meta TaskMeta
Chunks []Chunk Chunks []Chunk
} }
type taskEntry struct { type taskEntry struct {
Epoch int
NumTimeout int
Task Task Task Task
// A task fails if it's timeout or trainer reports it exits unnormally.
NumFailure int
} }
type taskQueues struct { type taskQueues struct {
Todo []taskEntry Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry Done []taskEntry
Failed []Task Failed []taskEntry
} }
// Service is the master server service. // Service is the master server service.
type Service struct { type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
timeoutMax int failureMax int
ready chan struct{} ready chan struct{}
store Store store Store
...@@ -73,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -73,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry var cur taskEntry
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
id++ id++
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
...@@ -83,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -83,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
if len(cur.Task.Chunks) > 0 { if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
result = append(result, cur) result = append(result, cur)
} }
...@@ -91,11 +97,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -91,11 +97,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
s := &Service{} s := &Service{}
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax s.failureMax = failureMax
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
...@@ -257,19 +263,10 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -257,19 +263,10 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil return nil
} }
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { func (s *Service) processFailedTask(t taskEntry, epoch int) {
return func() { if t.Task.Meta.Epoch != epoch {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
if t.Epoch != epoch {
// new epoch, task launched after the // new epoch, task launched after the
// schedule of this timeout check. // schedule of this timeout check or failed status report.
return return
} }
...@@ -280,17 +277,31 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -280,17 +277,31 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
} }
}() }()
delete(s.taskQueues.Pending, t.Task.ID) delete(s.taskQueues.Pending, t.Task.Meta.ID)
t.NumTimeout++ t.NumFailure++
if t.NumTimeout > s.timeoutMax { if t.NumFailure > s.failureMax {
log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout) log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) s.taskQueues.Failed = append(s.taskQueues.Failed, t)
return return
} }
log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout) log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
s.processFailedTask(t, epoch)
} }
} }
...@@ -339,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -339,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
t.Epoch++ t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:] s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
return err return err
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID) log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
...@@ -365,13 +376,12 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -365,13 +376,12 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.taskQueues.Pending[taskID] t, ok := s.taskQueues.Pending[taskID]
if !ok { if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err return nil
} }
// task finished, reset timeout // task finished, reset timeout
t.NumTimeout = 0 t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t) s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
...@@ -389,3 +399,22 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -389,3 +399,22 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
} }
return err return err
} }
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID]
if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil
}
s.processFailedTask(t, meta.Epoch)
return nil
}
...@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) { ...@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.ID != i { if ts[i].Task.Meta.ID != i {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }
......
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer)
endif()
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()
libpaddle_go_optimizer.a
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m) target_link_libraries(paddle_go_optimizer stdc++ m)
# Copy library to the required place.
# See: go/pserver/optimizer.go:
# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
add_custom_command(TARGET paddle_go_optimizer POST_BUILD
COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}"
)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING) if(WITH_TESTING)
# FIXME: this test requires pserver which is not managed by the test # FIXME: this test requires pserver which is not managed by the test
......
...@@ -19,7 +19,7 @@ def main(): ...@@ -19,7 +19,7 @@ def main():
# create parameters # create parameters
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig #TODO(zhihong) : replace optimizer with new OptimizerConfig
......
...@@ -3,11 +3,13 @@ package client_test ...@@ -3,11 +3,13 @@ package client_test
import ( import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -42,7 +44,8 @@ func initClient() [numPserver]int { ...@@ -42,7 +44,8 @@ func initClient() [numPserver]int {
ports[i] = p ports[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -99,18 +102,22 @@ func (l lister) List() []client.Server { ...@@ -99,18 +102,22 @@ func (l lister) List() []client.Server {
return l return l
} }
func ClientTest(t *testing.T, c *client.Client) { func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams() selected := c.BeginInitParams()
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
const numParameter = 100 const numParameter = 1000
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb") config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
wg.Add(1)
go func(i int) {
var p pserver.Parameter var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i) p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32 p.ElementType = pserver.Float32
...@@ -119,7 +126,10 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -119,7 +126,10 @@ func ClientTest(t *testing.T, c *client.Client) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wg.Done()
}(i)
} }
wg.Wait()
err = c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
...@@ -127,7 +137,7 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -127,7 +137,7 @@ func ClientTest(t *testing.T, c *client.Client) {
} }
var grads []pserver.Gradient var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ { for i := 0; i < numParameter; i++ {
var g pserver.Gradient var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i) g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32 g.ElementType = pserver.Float32
...@@ -135,30 +145,67 @@ func ClientTest(t *testing.T, c *client.Client) { ...@@ -135,30 +145,67 @@ func ClientTest(t *testing.T, c *client.Client) {
grads = append(grads, g) grads = append(grads, g)
} }
err = c.SendGrads(grads) const paramPerGroup = 10
const numGroups = numParameter / paramPerGroup
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err = c.SendGrads(gs)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wg.Done()
}(gs)
}
names := make([]string, numParameter) names := make([]string, numParameter)
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
names[i] = "p_" + strconv.Itoa(i) names[i] = "p_" + strconv.Itoa(i)
} }
params, err := c.GetParams(names) for i := 0; i < numGroups; i++ {
var ns []string
if i == numGroups-1 {
ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(ns []string) {
params, err := c.GetParams(ns)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(names) != len(params) { if len(ns) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
} }
for i := range params { for i := range params {
if names[i] != params[i].Name { if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name) t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
}
} }
wg.Done()
}(ns)
} }
wg.Wait()
} }
func TestNativeClient(t *testing.T) { func TestNativeClient(t *testing.T) {
...@@ -168,13 +215,14 @@ func TestNativeClient(t *testing.T) { ...@@ -168,13 +215,14 @@ func TestNativeClient(t *testing.T) {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])} servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
} }
c1 := client.NewClient(lister(servers), len(servers), selector(true)) c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1) testClient(t, c1)
} }
// TODO: tmperary disable etcdClient test for dependency of etcd) // EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient(t *testing.T) { func EtcdClient(t *testing.T) {
initEtcdClient() initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints) etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true)) c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2) testClient(t, c2)
} }
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
const ( const (
// PsDesired is etcd path for store desired pserver count // PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired" PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr // PsPath is the base dir for pserver to store their addr
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information // PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/" PsCheckpoint = "/checkpoints/"
...@@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil return idx, nil
} }
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key)
cancel()
if err != nil {
return []byte{}, err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return []byte{}, nil
}
v := kvs[0].Value
return v, nil
}
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value)) _, err := e.etcdClient.Put(ctx, key, string(value))
cancel() cancel()
if err != nil { if err != nil {
......
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/optimizer/optimizer.h" // #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>
...@@ -20,6 +19,7 @@ var nullPtr = unsafe.Pointer(uintptr(0)) ...@@ -20,6 +19,7 @@ var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct { type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType elementType ElementType
contentLen int
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
...@@ -38,25 +38,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -38,25 +38,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer { func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{} o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType o.elementType = paramWithConfigs.Param.ElementType
o.contentLen = len(paramWithConfigs.Param.Content)
p := paramWithConfigs.Param p := paramWithConfigs.Param
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State s := State
paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": len(p.Content), "ParamSize": paramBufferSize,
"ConfigSize": len(c), "ConfigSize": len(c),
"StateSize": len(s), "StateSize": len(s),
}).Info("New Optimizer Created with config:") }).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content))) cbuffer = C.malloc(paramBufferSize)
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize)
var cstate unsafe.Pointer var cstate unsafe.Pointer
if len(s) != 0 { if len(s) != 0 {
cstate = unsafe.Pointer(&s[0]) cstate = unsafe.Pointer(&s[0])
} }
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s))) C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
return o return o
} }
...@@ -68,8 +71,8 @@ func (o *optimizer) GetWeights() []byte { ...@@ -68,8 +71,8 @@ func (o *optimizer) GetWeights() []byte {
func (o *optimizer) GetStates() []byte { func (o *optimizer) GetStates() []byte {
var cbuffer *C.char var cbuffer *C.char
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer) cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len)) return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
} }
func (o *optimizer) UpdateParameter(g Gradient) error { func (o *optimizer) UpdateParameter(g Gradient) error {
...@@ -77,7 +80,11 @@ func (o *optimizer) UpdateParameter(g Gradient) error { ...@@ -77,7 +80,11 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float) if o.contentLen != len(g.Content) {
return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content))
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 { if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r) return fmt.Errorf("optimizer update returned error code: %d", r)
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
...@@ -21,14 +22,14 @@ import ( ...@@ -21,14 +22,14 @@ import (
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
// RPC error message.
const ( const (
// AlreadyInitialized is true if pserver is initialized
AlreadyInitialized = "pserver already initialized" AlreadyInitialized = "pserver already initialized"
// Uninitialized is true if pserver not fully initialized
Uninitialized = "pserver not fully initialized" Uninitialized = "pserver not fully initialized"
CheckpointMD5Failed = "checkpoint file MD5 validation failed"
) )
// Supported element types // Supported element types.
const ( const (
Int32 ElementType = iota Int32 ElementType = iota
UInt32 UInt32
...@@ -51,21 +52,15 @@ type ParameterWithConfig struct { ...@@ -51,21 +52,15 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
} }
// ParameterCheckpoint is Parameter and State checkpoint // checkpointMeta saves checkpoint metadata
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpoint signature
type checkpointMeta struct { type checkpointMeta struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Md5sum string `json:"md5sum"` MD5 string `json:"md5"`
Timestamp string `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
// Checkpoint is the pserver shard persist in file // Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter type Gradient Parameter
...@@ -81,12 +76,53 @@ type Service struct { ...@@ -81,12 +76,53 @@ type Service struct {
optMap map[string]*optimizer optMap map[string]*optimizer
} }
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
ParameterWithConfig
State []byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil {
return nil, err
}
var cpMeta checkpointMeta
if err = json.Unmarshal(v, &cpMeta); err != nil {
return nil, err
}
fn := filepath.Join(cpPath, cpMeta.UUID)
if _, err = os.Stat(fn); os.IsNotExist(err) {
return nil, err
}
content, err := ioutil.ReadFile(fn)
if err != nil {
return nil, err
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
return nil, errors.New(CheckpointMD5Failed)
}
dec := gob.NewDecoder(bytes.NewReader(content))
cp := Checkpoint{}
if err = dec.Decode(cp); err != nil {
return nil, err
}
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: time.Second * time.Duration(seconds), checkpointInterval: interval,
checkpointPath: path, checkpointPath: path,
client: client, client: client,
} }
...@@ -95,9 +131,11 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp ...@@ -95,9 +131,11 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
if cp != nil { if cp != nil {
for _, item := range cp { for _, item := range cp {
p := item.ParamConfig p := ParameterWithConfig{
st := item.State Param: item.Param,
s.optMap[p.Param.Name] = newOptimizer(p, st) Config: item.Config,
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
} }
} }
return s, nil return s, nil
...@@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error { ...@@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, 0, len(s.optMap)) cp := make([]parameterCheckpoint, len(s.optMap))
index := 0 index := 0
for name, opt := range s.optMap { for name, opt := range s.optMap {
var pc ParameterCheckpoint var pc parameterCheckpoint
pc.ParamConfig.Param.Name = name pc.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType pc.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights() pc.Param.Content = opt.GetWeights()
pc.State = opt.GetStates() pc.State = opt.GetStates()
cp[index] = pc cp[index] = pc
index++ index++
...@@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error { ...@@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
cpMeta := checkpointMeta{} cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String() cpMeta.Timestamp = time.Now().UnixNano()
h := md5.New() h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, _ := json.Marshal(cpMeta) cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil { if err != nil {
return err return err
} }
...@@ -219,8 +257,12 @@ func (s *Service) doCheckpoint() error { ...@@ -219,8 +257,12 @@ func (s *Service) doCheckpoint() error {
log.Info("checkpoint does not exists.") log.Info("checkpoint does not exists.")
} else { } else {
err = os.Remove(cpMeta.UUID) err = os.Remove(cpMeta.UUID)
if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
} }
}
f, err := os.Create(cpMeta.UUID) f, err := os.Create(cpMeta.UUID)
defer f.Close() defer f.Close()
if err != nil { if err != nil {
......
...@@ -31,7 +31,7 @@ func TestServiceFull(t *testing.T) { ...@@ -31,7 +31,7 @@ func TestServiceFull(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var p1 pserver.Parameter var p1 pserver.Parameter
...@@ -40,40 +40,40 @@ func TestServiceFull(t *testing.T) { ...@@ -40,40 +40,40 @@ func TestServiceFull(t *testing.T) {
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var param pserver.Parameter var param pserver.Parameter
err = s.GetParam("param_b", &param) err = s.GetParam("param_b", &param)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
if !reflect.DeepEqual(param, p1) { if !reflect.DeepEqual(param, p1) {
t.FailNow() t.Fatal("not equal:", param, p1)
} }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil) err = s.SendGrad(g1, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.SendGrad(g2, nil) err = s.SendGrad(g2, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
var param1 pserver.Parameter var param1 pserver.Parameter
err = s.GetParam("param_a", &param1) err = s.GetParam("param_a", &param1)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
// don't compare content, since it's already changed by // don't compare content, since it's already changed by
...@@ -82,7 +82,7 @@ func TestServiceFull(t *testing.T) { ...@@ -82,7 +82,7 @@ func TestServiceFull(t *testing.T) {
p.Content = nil p.Content = nil
if !reflect.DeepEqual(param1, p) { if !reflect.DeepEqual(param1, p) {
t.FailNow() t.Fatal("not equal:", param1, p)
} }
} }
...@@ -90,16 +90,16 @@ func TestMultipleInit(t *testing.T) { ...@@ -90,16 +90,16 @@ func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err.Error() != pserver.AlreadyInitialized { if err.Error() != pserver.AlreadyInitialized {
t.FailNow() t.Fatal(err)
} }
} }
...@@ -108,7 +108,7 @@ func TestUninitialized(t *testing.T) { ...@@ -108,7 +108,7 @@ func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.Fatal(err)
} }
} }
...@@ -154,12 +154,12 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -154,12 +154,12 @@ func TestBlockUntilInitialized(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.Fatal(err)
} }
wg.Wait() wg.Wait()
......
if(WITH_TESTING)
go_test(network_helper_test)
endif()
...@@ -8,13 +8,14 @@ add_subdirectory(gserver) ...@@ -8,13 +8,14 @@ add_subdirectory(gserver)
add_subdirectory(pserver) add_subdirectory(pserver)
add_subdirectory(trainer) add_subdirectory(trainer)
add_subdirectory(scripts) add_subdirectory(scripts)
add_subdirectory(optimizer)
add_subdirectory(string) add_subdirectory(string)
if(Boost_FOUND) if(Boost_FOUND)
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif() endif()
if(WITH_C_API) if(WITH_C_API)
......
...@@ -11,7 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) ...@@ -11,7 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc) cc_library(operator SRCS operator.cc DEPS op_desc device_context)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......
...@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { ...@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return ((0 <= idx.head) && (idx.head < size.head)); return ((0 <= idx.head) && (idx.head < size.head));
} }
/**
* \brief Check if a size and a stride create a Fortran order contiguous
* block of memory.
*/
template <int i>
HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
if (product(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return (get<0>(stride) == contiguous_stride &&
contiguous(size.tail, stride.tail, mul * get<0>(size)));
}
///\cond HIDDEN
// Base case of contiguous, check the nth stride is the size of
// the prefix multiply of n-1 dims.
template <>
inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
if (get<0>(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return get<0>(stride) == contiguous_stride;
}
///\endcond
/** /**
* \brief Compute exclusive prefix-multiply of a Dim. * \brief Compute exclusive prefix-multiply of a Dim.
*/ */
...@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { ...@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
} }
///\endcond ///\endcond
/**
* \brief Calculate strides of a contiguous array of the given size
*
* Sets the stride for any dimension with an extent of 1 to 0.
* \param size Dim object containing the size of the array.
* \param base The base stride to use.
* \return Dim object the same size as \p size with the strides.
*/
template <int i>
HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
int stride = size.head == 1 ? 0 : base;
return Dim<i>(stride, contiguous_strides(size.tail, base * size.head));
}
///\cond HIDDEN
// Base case of contiguous_strides
template <>
HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
int stride = size.head == 1 ? 0 : base;
return Dim<1>(stride);
}
///\endcond
/** /**
* Add two dimensions together * Add two dimensions together
*/ */
......
...@@ -58,24 +58,6 @@ TEST(Dim, Equality) { ...@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
EXPECT_EQ(paddle::framework::get<1>(c), 3); EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12); EXPECT_EQ(paddle::framework::get<2>(c), 12);
// contiguous_strides
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 0);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 10);
EXPECT_EQ(paddle::framework::get<2>(c), 0);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 0);
EXPECT_EQ(paddle::framework::get<1>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 2);
EXPECT_EQ(paddle::framework::get<2>(c), 6);
// generate from an index // generate from an index
auto size = paddle::framework::make_dim(4, 5, 2); auto size = paddle::framework::make_dim(4, 5, 2);
c = paddle::framework::Dim<3>(14, size); c = paddle::framework::Dim<3>(14, size);
...@@ -101,16 +83,6 @@ TEST(Dim, Bool) { ...@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
EXPECT_TRUE(a == a); EXPECT_TRUE(a == a);
EXPECT_FALSE(a == b); EXPECT_FALSE(a == b);
EXPECT_TRUE(a == c); EXPECT_TRUE(a == c);
// contiguous check
int x = 4, y = 5, z = 2;
paddle::framework::Dim<3> sizef(x, y, z);
paddle::framework::Dim<3> stridea(1, x, x*y);
paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
paddle::framework::Dim<3> stridec(1, x, 2*x*y);
EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
} }
TEST(Dim, Print) { TEST(Dim, Print) {
......
#include <paddle/framework/op_registry.h>
namespace paddle {
namespace framework {
template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}
template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}
template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
#pragma once #pragma once
#include <algorithm>
#include <type_traits>
#include "paddle/framework/attr_checker.h" #include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
//==================For test================//
class OpBase {
public:
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attr_map_;
virtual std::string Run() const = 0;
virtual ~OpBase() {}
};
//=========================================//
// helper class to set attribute type // helper class to set attribute type
struct AttrTypeHelper { struct AttrTypeHelper {
template <typename T> template <typename T>
...@@ -64,36 +53,6 @@ struct AttrTypeHelper { ...@@ -64,36 +53,6 @@ struct AttrTypeHelper {
} }
}; };
template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}
template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}
template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker { class OpProtoAndCheckerMaker {
public: public:
...@@ -103,28 +62,26 @@ class OpProtoAndCheckerMaker { ...@@ -103,28 +62,26 @@ class OpProtoAndCheckerMaker {
protected: protected:
void AddInput(const std::string& name, const std::string& comment) { void AddInput(const std::string& name, const std::string& comment) {
auto input = proto_->mutable_inputs()->Add(); auto input = proto_->mutable_inputs()->Add();
*(input->mutable_name()) = name; *input->mutable_name() = name;
*(input->mutable_comment()) = comment; *input->mutable_comment() = comment;
} }
void AddOutput(const std::string& name, const std::string& comment) { void AddOutput(const std::string& name, const std::string& comment) {
auto output = proto_->mutable_outputs()->Add(); auto output = proto_->mutable_outputs()->Add();
*(output->mutable_name()) = name; *output->mutable_name() = name;
*(output->mutable_comment()) = comment; *output->mutable_comment() = comment;
} }
template <typename T> template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name, TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment) { const std::string& comment) {
auto attr = proto_->mutable_attrs()->Add(); auto attr = proto_->mutable_attrs()->Add();
*(attr->mutable_name()) = name; *attr->mutable_name() = name;
*(attr->mutable_comment()) = comment; *attr->mutable_comment() = comment;
AttrTypeHelper::SetAttrType<T>(attr); AttrTypeHelper::SetAttrType<T>(attr);
return op_checker_->AddAttrChecker<T>(name); return op_checker_->AddAttrChecker<T>(name);
} }
void AddType(const std::string& op_type) { proto_->set_type(op_type); }
void AddComment(const std::string& comment) { void AddComment(const std::string& comment) {
*(proto_->mutable_comment()) = comment; *(proto_->mutable_comment()) = comment;
} }
...@@ -134,120 +91,127 @@ class OpProtoAndCheckerMaker { ...@@ -134,120 +91,127 @@ class OpProtoAndCheckerMaker {
}; };
class OpRegistry { class OpRegistry {
typedef std::function<OpBase*()> OpCreator; using OpCreator = std::function<OperatorBase*()>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
creators_[op_type] = []() { return new OpType; }; creators()[op_type] = [] { return new OpType; };
OpProto& op_proto = protos_[op_type]; OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers_[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker); ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized() == true, *op_proto.mutable_type() = op_type;
"Fail to initialize %s's OpProto !", op_type); PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
} }
static OpBase* CreateOp(const OpDesc& op_desc) { static OperatorBase* CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type(); std::string op_type = op_desc.type();
OpBase* op = (creators_.at(op_type))(); OperatorBase* op = creators().at(op_type)();
(op->inputs_).resize(op_desc.inputs_size()); op->desc_ = op_desc;
for (int i = 0; i < op_desc.inputs_size(); ++i) { op->inputs_.reserve((size_t)op_desc.inputs_size());
(op->inputs_)[i] = op_desc.inputs(i); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
} std::back_inserter(op->inputs_));
(op->outputs_).resize(op_desc.outputs_size()); op->outputs_.reserve((size_t)op_desc.outputs_size());
for (int i = 0; i < op_desc.outputs_size(); ++i) { std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
(op->outputs_)[i] = op_desc.outputs(i); std::back_inserter(op->outputs_));
} for (auto& attr : op_desc.attrs()) {
for (int i = 0; i < op_desc.attrs_size(); ++i) { op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
const AttrDesc& ith_attr = op_desc.attrs(i); }
std::string name = ith_attr.name(); op_checkers().at(op_type).Check(op->attrs_);
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); op->Init();
}
const OpAttrChecker& op_checker = op_checkers_.at(op_type);
op_checker.Check(op->attr_map_);
return op; return op;
} }
private: private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_; static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_; static std::unordered_map<std::string, OpProto> protos_;
return protos_;
};
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
}; };
std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_;
std::unordered_map<std::string, OpProto> OpRegistry::protos_;
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_;
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper { class OpRegisterHelper {
public: public:
OpRegisterHelper(std::string op_type) { OpRegisterHelper(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
} }
}; };
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
class __op_class##Register { \ struct __test_global_namespace_##uniq_name##__ {}; \
private: \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ __test_global_namespace_##uniq_name##__>::value, \
msg)
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
"REGISTER_OP must be in global namespace"); \
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##GPU_OR_CPU##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \ }; \
const OpRegisterHelper<__op_class, __op_maker_class> \ static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
__op_class##Register::reg(#__op_type); int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }
// Demos #define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
class CosineOp : public OpBase {
public: #define REGISTER_OP_CPU_KERNEL(type, KernelType) \
virtual std::string Run() const { REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale"))); #define USE_OP_WITHOUT_KERNEL(op_type) \
return msg; STATIC_ASSERT_GLOBAL_NAMESPACE( \
} __use_op_without_kernel_##op_type, \
}; "USE_OP_WITHOUT_KERNEL must be in global namespace"); \
extern int __op_register_##op_type##_handle__(); \
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { static int __use_op_ptr_##op_type##_without_kernel__ \
public: __attribute__((unused)) = __op_register_##op_type##_handle__()
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { #define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
AddInput("input", "input of cosine op"); STATIC_ASSERT_GLOBAL_NAMESPACE( \
AddOutput("output", "output of cosine op"); __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
AddAttr<float>("scale", "scale of cosine op") "USE_OP_KERNEL must be in global namespace"); \
.SetDefault(1.0) extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
.LargerThan(0.0); static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
AddType("cos"); __attribute__((unused)) = \
AddComment("This is cos op"); __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
}
}; #ifdef PADDLE_ONLY_CPU
#define USE_OP(op_type) \
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU);
class MyTestOp : public OpBase {
public: #else
virtual std::string Run() const { #define USE_OP(op_type) \
std::string msg = USE_OP_WITHOUT_KERNEL(op_type); \
"MyTestOp runs! test_attr = " + USE_OP_KERNEL(op_type, CPU); \
std::to_string(boost::get<int>(attr_map_.at("test_attr"))); USE_OP_KERNEL(op_type, GPU)
return msg; #endif
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is cos op");
}
};
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
public:
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddComment("This is my_test op");
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
op_desc.add_inputs("aa"); op_desc.add_inputs("aa");
op_desc.add_outputs("bb"); op_desc.add_outputs("bb");
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.3); attr->set_f(scale);
paddle::framework::OpBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run(); auto scope = std::make_shared<paddle::framework::Scope>();
std::string str = "CosineOp runs! scale = " + std::to_string(3.3); paddle::platform::CPUDeviceContext dev_ctx;
ASSERT_EQ(str.size(), debug_str.size()); op->Run(scope, dev_ctx);
for (size_t i = 0; i < debug_str.length(); ++i) { float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(debug_str[i], str[i]); ASSERT_EQ(scale_get, scale);
}
} }
TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, IllegalAttr) {
...@@ -35,7 +88,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -35,7 +88,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpBase* op __attribute__((unused)) = paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -54,15 +107,14 @@ TEST(OpRegistry, DefaultValue) { ...@@ -54,15 +107,14 @@ TEST(OpRegistry, DefaultValue) {
op_desc.add_inputs("aa"); op_desc.add_inputs("aa");
op_desc.add_outputs("bb"); op_desc.add_outputs("bb");
paddle::framework::OpBase* op = ASSERT_TRUE(op_desc.IsInitialized());
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run(); auto scope = std::make_shared<paddle::framework::Scope>();
float default_value = 1.0; paddle::platform::CPUDeviceContext dev_ctx;
std::string str = "CosineOp runs! scale = " + std::to_string(default_value); op->Run(scope, dev_ctx);
ASSERT_EQ(str.size(), debug_str.size()); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
} }
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
...@@ -74,7 +126,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -74,7 +126,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpBase* op __attribute__((unused)) = paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -93,7 +145,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -93,7 +145,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OpBase* op __attribute__((unused)) = paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -111,12 +163,11 @@ TEST(OpRegistry, CustomChecker) { ...@@ -111,12 +163,11 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
paddle::framework::OpBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run(); paddle::platform::CPUDeviceContext dev_ctx;
std::string str = "MyTestOp runs! test_attr = " + std::to_string(4); auto scope = std::make_shared<paddle::framework::Scope>();
ASSERT_EQ(str.size(), debug_str.size()); op->Run(scope, dev_ctx);
for (size_t i = 0; i < debug_str.length(); ++i) { int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(debug_str[i], str[i]); ASSERT_EQ(test_attr, 4);
}
} }
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
ss << "type = " << desc_.type() << "\n";
ss << "inputs = [";
for (auto& ipt : inputs_) {
ss << ipt << ", ";
}
ss << "]\n";
ss << "outputs = [";
for (auto& opt : outputs_) {
ss << opt << ", ";
}
ss << "]\n";
ss << "attr_keys = [";
for (auto& attr : attrs_) {
ss << attr.first << ", ";
}
ss << "]\n";
return ss.str();
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace framework {
class OperatorBase;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
* should always construct a proto message OpDesc and call
* OpRegistry::CreateOp(op_desc) to get an Operator instance.
*/
class OperatorBase {
public:
virtual ~OperatorBase() {}
template <typename T>
inline const T& GetAttr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}
std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0;
protected:
std::string Type() const { return desc_.type(); }
public:
OpDesc desc_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attrs_;
};
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
};
template <typename T>
struct VarToTensor {};
template <>
struct VarToTensor<Tensor*> {
Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
};
template <>
struct VarToTensor<const Tensor*> {
const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
};
class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
platform::Place place_;
OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};
struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}
void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs;
VarNamesToTensors(scope, outputs_, &outs);
InferShape(ins, outs);
};
private:
template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& var_names,
std::vector<T>* container) const {
container->reserve(var_names.size());
VarToTensor<T> convert;
for (auto& name : var_names) {
auto var = scope->GetVariable(name);
if (var != nullptr) {
container->push_back(convert(var));
} else {
container->push_back(nullptr);
}
}
}
protected:
virtual void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const = 0;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
class OperatorTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
}
public:
float x = 0;
};
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is test op");
}
};
class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {}
};
class CPUKernelTest : public OpKernel {
public:
void Compute(const KernelContext& context) const {
float scale = context.op_.GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
std::cout << "this is cpu kernel" << std::endl;
std::cout << context.op_.DebugString() << std::endl;
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
TEST(OpKernel, all) {
using namespace paddle::framework;
OpDesc op_desc;
op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
delete op;
}
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cstdint>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
...@@ -26,31 +27,65 @@ namespace framework { ...@@ -26,31 +27,65 @@ namespace framework {
class Tensor { class Tensor {
public: public:
Tensor() : offset_(0) {}
explicit Tensor(const DDim& dims) : dims_(dims), offset_(0) {}
template <typename T> template <typename T>
const T* data() const { const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(
"Tensor::data must be called after Tensor::mutable_data."); holder_ != nullptr,
return static_cast<const T*>(holder_->Ptr()); "Tenosr has not been initialized. Call Tensor::mutable_data first.");
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->Ptr()) + offset_);
} }
template <typename T, // must be POD types template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims, paddle::platform::Place place) { T* mutable_data(DDim dims, paddle::platform::Place place) {
dims_ = dims;
if (holder_ == nullptr || if (holder_ == nullptr ||
!(holder_->Place() == !(holder_->Place() ==
place) /* some versions of boost::variant don't have operator!= */ place) /* some versions of boost::variant don't have operator!= */
|| holder_->Size() < product(dims) * sizeof(T)) { || holder_->Size() < product(dims) * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T))); holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
offset_ = 0;
} }
return static_cast<T*>(holder_->Ptr()); return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) +
offset_);
} }
template <typename T, // must be POD types void ShareDataFrom(const Tensor& src) {
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> PADDLE_ENFORCE(src.holder_ != nullptr,
T* mutable_data(DDim dims) { "Can not share data from an uninitialized tensor.");
return mutable_data<T>(dims, paddle::platform::get_place()); holder_ = src.holder_;
dims_ = src.dims_;
offset_ = src.offset_;
} }
Tensor Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE(holder_ != nullptr,
"The sliced tenosr has not been initialized.");
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
"Slice index is less than zero or out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
std::vector<int> d = vectorize(dims_);
int base = 1;
for (size_t i = 1; i < d.size(); ++i) {
base *= d[i];
}
Tensor dst;
dst.holder_ = holder_;
dst.dims_ = dims_;
dst.dims_[0] = end_idx - begin_idx;
dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize();
return dst;
}
DDim dims() const { return dims_; }
private: private:
// Placeholder hides type T, so it doesn't appear as a template // Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable. // parameter of Variable.
...@@ -59,6 +94,7 @@ class Tensor { ...@@ -59,6 +94,7 @@ class Tensor {
virtual void* Ptr() const = 0; virtual void* Ptr() const = 0;
virtual paddle::platform::Place Place() const = 0; virtual paddle::platform::Place Place() const = 0;
virtual size_t Size() const = 0; virtual size_t Size() const = 0;
virtual size_t TypeSize() const = 0;
}; };
template <typename T> template <typename T>
...@@ -85,6 +121,7 @@ class Tensor { ...@@ -85,6 +121,7 @@ class Tensor {
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); } virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t Size() const { return size_; } virtual size_t Size() const { return size_; }
virtual paddle::platform::Place Place() const { return place_; } virtual paddle::platform::Place Place() const { return place_; }
virtual size_t TypeSize() const { return sizeof(T); }
std::unique_ptr<T, Deleter> ptr_; std::unique_ptr<T, Deleter> ptr_;
paddle::platform::Place place_; // record the place of ptr_. paddle::platform::Place place_; // record the place of ptr_.
...@@ -92,6 +129,8 @@ class Tensor { ...@@ -92,6 +129,8 @@ class Tensor {
}; };
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_;
size_t offset_; // marks the begin of tensor data area.
}; };
} // namespace framework } // namespace framework
......
...@@ -15,15 +15,27 @@ ...@@ -15,15 +15,27 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
TEST(Tensor, ASSERT) { TEST(Tensor, Dims) {
paddle::framework::Tensor cpu_tensor; using namespace paddle::framework;
using namespace paddle::platform;
Tensor tt(make_ddim({2, 3, 4}));
DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
}
}
TEST(Tensor, DataAssert) {
paddle::framework::Tensor src_tensor;
bool caught = false; bool caught = false;
try { try {
const double* p __attribute__((unused)) = cpu_tensor.data<double>(); src_tensor.data<double>();
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Tensor::data must be called after Tensor::mutable_data."; std::string msg =
"Tenosr has not been initialized. Call Tensor::mutable_data first.";
const char* what = err.what(); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]); ASSERT_EQ(what[i], msg[i]);
...@@ -32,54 +44,138 @@ TEST(Tensor, ASSERT) { ...@@ -32,54 +44,138 @@ TEST(Tensor, ASSERT) {
ASSERT_TRUE(caught); ASSERT_TRUE(caught);
} }
/* mutable_data() is not tested at present /* following tests are not available at present
because Memory::Alloc() and Memory::Free() have not been ready. because Memory::Alloc() and Memory::Free() have not been ready.
TEST(Tensor, MutableData) { TEST(Tensor, MutableData) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
{ {
Tensor cpu_tensor; Tensor src_tensor;
float* p1 = nullptr; float* p1 = nullptr;
float* p2 = nullptr; float* p2 = nullptr;
// initialization // initialization
p1 = cpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace()); p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
EXPECT_NE(p1, nullptr); EXPECT_NE(p1, nullptr);
// set cpu_tensor a new dim with large size // set src_tensor a new dim with large size
// momery is supposed to be re-allocated // momery is supposed to be re-allocated
p2 = cpu_tensor.mutable_data<float>(make_ddim({3, 4})); p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
EXPECT_NE(p2, nullptr); EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2); EXPECT_NE(p1, p2);
// set cpu_tensor a new dim with same size // set src_tensor a new dim with same size
// momery block is supposed to be unchanged // momery block is supposed to be unchanged
p1 = cpu_tensor.mutable_data<float>(make_ddim({2, 2, 3})); p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
// set cpu_tensor a new dim with smaller size // set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged // momery block is supposed to be unchanged
p2 = cpu_tensor.mutable_data<float>(make_ddim({2, 2})); p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
{ {
Tensor gpu_tensor; Tensor src_tensor;
float* p1 = nullptr; float* p1 = nullptr;
float* p2 = nullptr; float* p2 = nullptr;
// initialization // initialization
p1 = gpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace()); p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
EXPECT_NE(p1, nullptr); EXPECT_NE(p1, nullptr);
// set gpu_tensor a new dim with large size // set src_tensor a new dim with large size
// momery is supposed to be re-allocated // momery is supposed to be re-allocated
p2 = gpu_tensor.mutable_data<float>(make_ddim({3, 4})); p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace());
EXPECT_NE(p2, nullptr); EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2); EXPECT_NE(p1, p2);
// set gpu_tensor a new dim with same size // set src_tensor a new dim with same size
// momery block is supposed to be unchanged // momery block is supposed to be unchanged
p1 = gpu_tensor.mutable_data<float>(make_ddim({2, 2, 3})); p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
// set gpu_tensor a new dim with smaller size // set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged // momery block is supposed to be unchanged
p2 = gpu_tensor.mutable_data<float>(make_ddim({2, 2})); p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
} }
TEST(Tensor, ShareDataFrom) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
Tensor dst_tensor;
// Try to share data form uninitialized tensor
bool caught = false;
try {
dst_tensor.ShareDataFrom(src_tensor);
} catch (EnforceNotMet err) {
caught = true;
std::string msg = "Can not share data from an uninitialized tensor.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(caught);
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
dst_tensor.ShareDataFrom(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
{
Tensor src_tensor;
Tensor dst_tensor;
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
dst_tensor.ShareDataFrom(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
}
TEST(Tensor, Slice) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
Tensor slice_tensor = src_tensor.Slice(1, 3);
DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
EXPECT_EQ(slice_dims[2], 4);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
}
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
Tensor slice_tensor = src_tensor.Slice(2, 6);
DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
}
}
*/ */
\ No newline at end of file
if(WITH_GPU)
nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim)
else()
cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim)
endif()
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
#include <paddle/framework/op_registry.h>
#include <paddle/framework/tensor.h>
#include <paddle/operators/add_op.h>
namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr,
"Inputs/Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of Add Op's dimension must be same.");
// Need set dims in Tensor
// outputs[0]->set_dims(inputs[0]->dims())
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Two Element Add Operator.
The equation is: Out = X + Y
)DOC");
}
};
} // namespace op
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL(
add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>);
\ No newline at end of file
#include <paddle/operators/add_op.h>
#include <paddle/framework/op_registry.h>
REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace>);
\ No newline at end of file
#pragma once
#include <glog/logging.h>
#include <paddle/framework/operator.h>
namespace paddle {
namespace operators {
template <typename Place>
class AddKernel : public framework::OpKernel {
public:
void Compute(const KernelContext &context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name();
}
};
} // namespace op
} // namespace paddle
#include <gtest/gtest.h>
#define private public
#include <paddle/framework/op_registry.h>
USE_OP(add_two);
TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("add_two");
ASSERT_NE(it, protos.end());
}
\ No newline at end of file
...@@ -44,8 +44,8 @@ paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto, ...@@ -44,8 +44,8 @@ paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
const int state_len) { const int state_len) {
paddle_optimizer* optimizer = new paddle_optimizer; paddle_optimizer* optimizer = new paddle_optimizer;
std::string config(config_proto, config_proto + config_proto_len); std::string config(config_proto, config_proto + config_proto_len);
Tensor* parameter = Tensor* parameter = new Tensor(reinterpret_cast<float*>(param_buffer),
new Tensor(reinterpret_cast<float*>(param_buffer), num_bytes); num_bytes / sizeof(float));
optimizer->impl = ParameterOptimizer::Create(config, parameter); optimizer->impl = ParameterOptimizer::Create(config, parameter);
if (state != nullptr) { if (state != nullptr) {
std::string s(state, state + state_len); std::string s(state, state + state_len);
...@@ -65,7 +65,8 @@ int paddle_update_parameter(paddle_optimizer* o, ...@@ -65,7 +65,8 @@ int paddle_update_parameter(paddle_optimizer* o,
int num_bytes) { int num_bytes) {
// TOOD(zhihong): datatype not work. need to add the runtime datatype // TOOD(zhihong): datatype not work. need to add the runtime datatype
auto grad_type = reinterpret_cast<const float*>(grad_buffer); auto grad_type = reinterpret_cast<const float*>(grad_buffer);
Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes); Tensor* gradient =
new Tensor(const_cast<float*>(grad_type), num_bytes / sizeof(float));
o->impl->Update(gradient); o->impl->Update(gradient);
return PADDLE_SUCCESS; return PADDLE_SUCCESS;
} }
......
...@@ -6,6 +6,13 @@ nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) ...@@ -6,6 +6,13 @@ nv_library(gpu_info SRCS gpu_info.cc DEPS gflags)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_library(dynamic_loader SRCS dynload/dynamic_loader.cc DEPS gflags glog) add_subdirectory(dynload)
nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 gpu_info) IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
ELSE()
set(GPU_CTX_DEPS)
ENDIF()
cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device();
}
#endif
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -23,8 +20,9 @@ limitations under the License. */ ...@@ -23,8 +20,9 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#include "paddle/platform/place.h" #include <paddle/platform/place.h>
#include "unsupported/Eigen/CXX11/Tensor" #include <memory>
#include <unsupported/Eigen/CXX11/Tensor>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -32,9 +30,29 @@ namespace platform { ...@@ -32,9 +30,29 @@ namespace platform {
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0;
template <typename DeviceType>
DeviceType* get_eigen_device();
}; };
class CPUDeviceContext : public DeviceContext {}; class CPUDeviceContext : public DeviceContext {
public:
Eigen::DefaultDevice* eigen_device() {
if (!eigen_device_) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
return eigen_device_.get();
}
Place GetPlace() const override {
Place retv = CPUPlace();
return retv;
}
private:
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
...@@ -58,8 +76,13 @@ class CUDADeviceContext : public DeviceContext { ...@@ -58,8 +76,13 @@ class CUDADeviceContext : public DeviceContext {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_), paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed"); "cudaStreamCreate failed");
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_ = new Eigen::GpuDevice(eigen_stream_); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
Place GetPlace() const override {
Place retv = GPUPlace();
return retv;
} }
void Wait() { void Wait() {
...@@ -69,7 +92,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -69,7 +92,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() { return stream_; } cudaStream_t stream() { return stream_; }
Eigen::GpuDevice eigen_device() { return *eigen_device_; } Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
...@@ -134,10 +157,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -134,10 +157,8 @@ class CUDADeviceContext : public DeviceContext {
rand_generator_) == CURAND_STATUS_SUCCESS, rand_generator_) == CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed"); "curandDestroyGenerator failed");
} }
eigen_stream_.reset();
delete eigen_stream_; eigen_device_.reset();
delete eigen_device_;
paddle::platform::throw_on_error(cudaStreamDestroy(stream_), paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
"cudaStreamDestroy failed"); "cudaStreamDestroy failed");
} }
...@@ -146,8 +167,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -146,8 +167,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlace gpu_place_; GPUPlace gpu_place_;
cudaStream_t stream_; cudaStream_t stream_;
Eigen::CudaStreamDevice* eigen_stream_; std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
Eigen::GpuDevice* eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
cublasHandle_t blas_handle_{nullptr}; cublasHandle_t blas_handle_{nullptr};
...@@ -156,6 +177,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -156,6 +177,8 @@ class CUDADeviceContext : public DeviceContext {
int random_seed_; int random_seed_;
curandGenerator_t rand_generator_{nullptr}; curandGenerator_t rand_generator_{nullptr};
}; };
#endif #endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -15,13 +15,26 @@ limitations under the License. */ ...@@ -15,13 +15,26 @@ limitations under the License. */
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
TEST(CUDADeviceContext, Init) { using DEVICE_GPU = Eigen::GpuDevice;
TEST(Device, Init) {
int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) {
paddle::platform::DeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<DEVICE_GPU>();
ASSERT_NE(nullptr, gpu_device);
delete device_context;
}
}
TEST(Device, CUDADeviceContext) {
int count = paddle::platform::GetDeviceCount(); int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
paddle::platform::CUDADeviceContext* device_context = paddle::platform::CUDADeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i); new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = device_context->eigen_device(); Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device.stream()); ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle(); cublasHandle_t cublas_handle = device_context->cublas_handle();
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc)
#include <paddle/platform/dynload/cublas.h>
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
...@@ -23,8 +23,8 @@ namespace paddle { ...@@ -23,8 +23,8 @@ namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
std::once_flag cublas_dso_flag; extern std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr; extern void *cublas_dso_handle;
/** /**
* The following macro definition can generate structs * The following macro definition can generate structs
...@@ -34,10 +34,10 @@ void *cublas_dso_handle = nullptr; ...@@ -34,10 +34,10 @@ void *cublas_dso_handle = nullptr;
* note: default dynamic linked libs * note: default dynamic linked libs
*/ */
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
cublasStatus_t operator()(Args... args) { \ inline cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \ typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, \ std::call_once(cublas_dso_flag, \
paddle::platform::dynload::GetCublasDsoHandle, \ paddle::platform::dynload::GetCublasDsoHandle, \
...@@ -45,62 +45,46 @@ void *cublas_dso_handle = nullptr; ...@@ -45,62 +45,46 @@ void *cublas_dso_handle = nullptr;
void *p_##__name = dlsym(cublas_dso_handle, #__name); \ void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \ return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
} \ } \
} __name; // struct DynLoad__##__name }; \
extern DynLoad__##__name __name
#else #else
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ inline template <typename... Args> \
cublasStatus_t operator()(Args... args) { \ cublasStatus_t operator()(Args... args) { \
return __name(args...); \ return __name(args...); \
} \ } \
} __name; // struct DynLoad__##__name }; \
extern DynLoad__##__name __name
#endif #endif
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
// include all needed cublas functions in HPPL
// clang-format off
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv) \ __macro(cublasSgemv); \
__macro(cublasDgemv) \ __macro(cublasDgemv); \
__macro(cublasSgemm) \ __macro(cublasSgemm); \
__macro(cublasDgemm) \ __macro(cublasDgemm); \
__macro(cublasSgeam) \ __macro(cublasSgeam); \
__macro(cublasDgeam) \ __macro(cublasDgeam); \
__macro(cublasCreate_v2); \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate) __macro(cublasDestroy_v2); \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy) __macro(cublasSetStream_v2); \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream) __macro(cublasSetPointerMode_v2); \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode) __macro(cublasGetPointerMode_v2); \
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode) __macro(cublasSgemmBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched) __macro(cublasDgemmBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched) __macro(cublasCgemmBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched) __macro(cublasZgemmBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched) __macro(cublasSgetrfBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched) __macro(cublasSgetriBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched) __macro(cublasDgetrfBatched); \
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched) __macro(cublasDgetriBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_WRAP CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP);
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH
// clang-format on #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
#else
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
#endif
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
#include <paddle/platform/dynload/cudnn.h>
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUDNN_DNN_ROUTINE_EACH(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_R2(DEFINE_WRAP);
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_R5
CUDNN_DNN_ROUTINE_EACH_R5(DEFINE_WRAP);
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
\ No newline at end of file
...@@ -23,12 +23,12 @@ namespace paddle { ...@@ -23,12 +23,12 @@ namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
std::once_flag cudnn_dso_flag; extern std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr; extern void* cudnn_dso_handle;
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
...@@ -39,17 +39,19 @@ void* cudnn_dso_handle = nullptr; ...@@ -39,17 +39,19 @@ void* cudnn_dso_handle = nullptr;
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \ return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ }; \
extern struct DynLoad__##__name __name
#else #else
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \ return __name(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ }; \
extern DynLoad__##__name __name
#endif #endif
...@@ -57,80 +59,73 @@ void* cudnn_dso_handle = nullptr; ...@@ -57,80 +59,73 @@ void* cudnn_dso_handle = nullptr;
* include all needed cudnn functions in HPPL * include all needed cudnn functions in HPPL
* different cudnn version has different interfaces * different cudnn version has different interfaces
**/ **/
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro) \ #define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor) \ __macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx) \ __macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnGetConvolutionNdForwardOutputDim) \ __macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm) \ __macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor) \ __macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor) \ __macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor) \ __macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor) \ __macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetPooling2dDescriptor) \ __macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnDestroyFilterDescriptor) \ __macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor) \ __macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor) \ __macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor) \ __macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor) \ __macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor) \ __macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnCreate) \ __macro(cudnnCreate); \
__macro(cudnnDestroy) \ __macro(cudnnDestroy); \
__macro(cudnnSetStream) \ __macro(cudnnSetStream); \
__macro(cudnnActivationForward) \ __macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward) \ __macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias) \ __macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize) \ __macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor) \ __macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward) \ __macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward) \ __macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward) \ __macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward) \ __macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion) \ __macro(cudnnGetVersion); \
__macro(cudnnGetErrorString) __macro(cudnnGetErrorString);
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
__macro(cudnnAddTensor) \ __macro(cudnnAddTensor); \
__macro(cudnnConvolutionBackwardData) \ __macro(cudnnConvolutionBackwardData); \
__macro(cudnnConvolutionBackwardFilter) __macro(cudnnConvolutionBackwardFilter);
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
// APIs available after R3: // APIs available after R3:
#if CUDNN_VERSION >= 3000 #if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \ __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize); \
__macro(cudnnGetConvolutionBackwardDataAlgorithm) \ __macro(cudnnGetConvolutionBackwardDataAlgorithm); \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ __macro(cudnnGetConvolutionBackwardFilterAlgorithm); \
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize) __macro(cudnnGetConvolutionBackwardDataWorkspaceSize);
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#endif #endif
// APIs available after R4: // APIs available after R4:
#if CUDNN_VERSION >= 4007 #if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
__macro(cudnnBatchNormalizationForwardTraining) \ __macro(cudnnBatchNormalizationForwardTraining); \
__macro(cudnnBatchNormalizationForwardInference) \ __macro(cudnnBatchNormalizationForwardInference); \
__macro(cudnnBatchNormalizationBackward) __macro(cudnnBatchNormalizationBackward);
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#endif #endif
// APIs in R5 // APIs in R5
#if CUDNN_VERSION >= 5000 #if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
__macro(cudnnCreateActivationDescriptor) \ __macro(cudnnCreateActivationDescriptor); \
__macro(cudnnSetActivationDescriptor) \ __macro(cudnnSetActivationDescriptor); \
__macro(cudnnGetActivationDescriptor) \ __macro(cudnnGetActivationDescriptor); \
__macro(cudnnDestroyActivationDescriptor) __macro(cudnnDestroyActivationDescriptor);
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#endif #endif
#undef CUDNN_DNN_ROUTINE_EACH
// clang-format on
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
#include <paddle/platform/dynload/curand.h>
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag curand_dso_flag;
void *curand_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CURAND_RAND_ROUTINE_EACH(DEFINE_WRAP);
}
}
}
\ No newline at end of file
...@@ -22,10 +22,10 @@ limitations under the License. */ ...@@ -22,10 +22,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
std::once_flag curand_dso_flag; extern std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr; extern void *curand_dso_handle;
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
curandStatus_t operator()(Args... args) { \ curandStatus_t operator()(Args... args) { \
...@@ -36,32 +36,29 @@ void *curand_dso_handle = nullptr; ...@@ -36,32 +36,29 @@ void *curand_dso_handle = nullptr;
void *p_##__name = dlsym(curand_dso_handle, #__name); \ void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \ return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ }; \
extern DynLoad__##__name __name
#else #else
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
curandStatus_t operator()(Args... args) { \ curandStatus_t operator()(Args... args) { \
return __name(args...); \ return __name(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ }; \
extern DynLoad__##__name __name
#endif #endif
/* include all needed curand functions in HPPL */
// clang-format off
#define CURAND_RAND_ROUTINE_EACH(__macro) \ #define CURAND_RAND_ROUTINE_EACH(__macro) \
__macro(curandCreateGenerator) \ __macro(curandCreateGenerator); \
__macro(curandSetStream) \ __macro(curandSetStream); \
__macro(curandSetPseudoRandomGeneratorSeed)\ __macro(curandSetPseudoRandomGeneratorSeed); \
__macro(curandGenerateUniform) \ __macro(curandGenerateUniform); \
__macro(curandGenerateUniformDouble) \ __macro(curandGenerateUniformDouble); \
__macro(curandDestroyGenerator) __macro(curandDestroyGenerator);
// clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) CURAND_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CURAND_WRAP);
#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <paddle/framework/scope.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace pd = paddle::framework;
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of Paddle Paddle");
py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class.
All parameter, weight, gradient are variables in Paddle.
)DOC")
.def("is_int", [](const pd::Variable& var) { return var.IsType<int>(); })
.def("set_int",
[](pd::Variable& var, int val) -> void {
*var.GetMutable<int>() = val;
})
.def("get_int",
[](const pd::Variable& var) -> int { return var.Get<int>(); });
py::class_<pd::Scope, std::shared_ptr<pd::Scope>>(m, "Scope")
.def(py::init<const std::shared_ptr<pd::Scope>&>())
.def("get_var",
&pd::Scope::GetVariable,
py::return_value_policy::reference)
.def("create_var",
&pd::Scope::CreateVariable,
py::return_value_policy::reference);
return m.ptr();
}
#!/bin/bash
set -e
ANDROID_STANDALONE_TOOLCHAIN=$HOME/android-toolchain-gcc
TMP_DIR=$HOME/$JOB/tmp
mkdir -p $TMP_DIR
cd $TMP_DIR
wget -q https://dl.google.com/android/repository/android-ndk-r14b-linux-x86_64.zip
unzip -q android-ndk-r14b-linux-x86_64.zip
chmod +x $TMP_DIR/android-ndk-r14b/build/tools/make-standalone-toolchain.sh
$TMP_DIR/android-ndk-r14b/build/tools/make-standalone-toolchain.sh --force --arch=arm --platform=android-21 --install-dir=$ANDROID_STANDALONE_TOOLCHAIN
cd $HOME
rm -rf $TMP_DIR
# Create the build directory for CMake.
mkdir -p $TRAVIS_BUILD_DIR/build_android
cd $TRAVIS_BUILD_DIR/build_android
# Compile paddle binaries
cmake -DCMAKE_SYSTEM_NAME=Android \
-DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
-DANDROID_ABI=armeabi-v7a \
-DANDROID_ARM_NEON=ON \
-DANDROID_ARM_MODE=ON \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_STYLE_CHECK=OFF \
..
make -j `nproc`
...@@ -22,7 +22,8 @@ DECLARE_string(save_dir); ...@@ -22,7 +22,8 @@ DECLARE_string(save_dir);
namespace paddle { namespace paddle {
NewRemoteParameterUpdater::NewRemoteParameterUpdater( NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config, const std::string pserverSpec) const OptimizationConfig &config, const std::string pserverSpec)
: parameterClient_(-1), : trainerConfig_(config),
parameterClient_(-1),
newParameters_(nullptr), newParameters_(nullptr),
newGradients_(nullptr), newGradients_(nullptr),
pserverSpec_(pserverSpec) {} pserverSpec_(pserverSpec) {}
...@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init( ...@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init(
LOG(INFO) << "paddle_begin_init_params start"; LOG(INFO) << "paddle_begin_init_params start";
for (int i = 0; i < parameterSize(); ++i) { for (int i = 0; i < parameterSize(); ++i) {
auto paramConfig = parameters_[i]->getConfig(); auto paramConfig = parameters_[i]->getConfig();
std::string bytes = paramConfig.SerializeAsString(); LOG(INFO) << "old param config: " << paramConfig.DebugString();
// FIXME(typhoonzero): convert old paramConfig to optimizerConfig
OptimizerConfig optimizeConfigV2;
auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
sgdConfigV2->set_momentum(paramConfig.momentum());
sgdConfigV2->set_decay(paramConfig.decay_rate());
optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
auto constlr = optimizeConfigV2.mutable_const_lr();
constlr->set_learning_rate(paramConfig.learning_rate());
if (trainerConfig_.algorithm() == "sgd") {
optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
// FIXME: config all algorithms
} else {
optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
}
std::string bytes = optimizeConfigV2.SerializeAsString();
const char *array = bytes.data(); const char *array = bytes.data();
int size = (int)bytes.size(); int size = (int)bytes.size();
paddle_init_param( paddle_init_param(
...@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) { ...@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
void NewRemoteParameterUpdater::startPass() {} void NewRemoteParameterUpdater::startPass() {}
bool NewRemoteParameterUpdater::finishPass() { return true; } bool NewRemoteParameterUpdater::finishPass() { return true; }
} } // namespace paddle
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <functional> #include <functional>
#include <thread> #include <thread>
#include "OptimizerConfig.pb.h"
#include "ParameterUpdater.h" #include "ParameterUpdater.h"
#include "libpaddle_pserver_cclient.h" #include "libpaddle_pserver_cclient.h"
#include "paddle/pserver/ParameterClient2.h" #include "paddle/pserver/ParameterClient2.h"
...@@ -101,6 +102,7 @@ private: ...@@ -101,6 +102,7 @@ private:
} }
protected: protected:
const OptimizationConfig& trainerConfig_;
/// internal parameter client object for exchanging data with pserver /// internal parameter client object for exchanging data with pserver
paddle_pserver_client parameterClient_; paddle_pserver_client parameterClient_;
/// the parameters for new pserver client /// the parameters for new pserver client
......
...@@ -26,10 +26,17 @@ endif(WITH_GOLANG) ...@@ -26,10 +26,17 @@ endif(WITH_GOLANG)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
${CMAKE_CURRENT_BINARY_DIR}/setup.py) ${CMAKE_CURRENT_BINARY_DIR}/setup.py)
add_custom_command(OUTPUT ${PROJ_ROOT}/python/paddle/v2/framework/core.so
COMMAND cmake -E copy $<TARGET_FILE:paddle_pybind> ${PROJ_ROOT}/python/paddle/v2/framework/core.so
DEPENDS paddle_pybind)
add_custom_target(copy_paddle_pybind ALL DEPENDS ${PROJ_ROOT}/python/paddle/v2/framework/core.so)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp
DEPENDS gen_proto_py framework_py_proto ${PY_FILES} ${external_project_dependencies} ${COPY_PADDLE_MASTER}) DEPENDS gen_proto_py copy_paddle_pybind framework_py_proto ${PY_FILES} ${external_project_dependencies} ${COPY_PADDLE_MASTER})
add_custom_target(paddle_python ALL DEPENDS add_custom_target(paddle_python ALL DEPENDS
${OUTPUT_DIR}/.timestamp) ${OUTPUT_DIR}/.timestamp)
......
"""
Default scope function.
`Paddle` manages Scope as programming language's scope. It just a
thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent.
Invoking `create_var/get_var` can `create/get` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope.
A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""
import paddle.v2.framework.core
import threading
__tl_scope__ = threading.local()
__all__ = [
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var',
'get_var', 'scoped_function'
]
def get_cur_scope():
"""
Get current scope.
:rtype: paddle.v2.framework.core.Scope
"""
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
if cur_scope_stack is None:
__tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None))
return __tl_scope__.cur_scope[-1]
def enter_local_scope():
"""
Enter a new local scope
"""
cur_scope = get_cur_scope()
new_scope = paddle.v2.framework.core.Scope(cur_scope)
__tl_scope__.cur_scope.append(new_scope)
def leave_local_scope():
"""
Leave local scope
"""
__tl_scope__.cur_scope.pop()
def create_var(name):
"""
create variable in current scope.
"""
return get_cur_scope().create_var(name)
def get_var(name):
"""
get variable in current scope.
"""
return get_cur_scope().get_var(name)
def scoped_function(func):
"""
invoke `func` in new scope.
:param func: a callable function that will be run in new scope.
:type func: callable
"""
enter_local_scope()
try:
func()
except:
raise
finally:
leave_local_scope()
add_python_test(test_framework test_protobuf.py) add_python_test(test_framework test_protobuf.py test_scope.py
test_default_scope_funcs.py)
from paddle.v2.framework.default_scope_funcs import *
import unittest
class TestDefaultScopeFuncs(unittest.TestCase):
def test_cur_scope(self):
self.assertIsNotNone(get_cur_scope())
def test_none_variable(self):
self.assertIsNone(get_var("test"))
def test_create_var_get_var(self):
var_a = create_var("var_a")
self.assertIsNotNone(var_a)
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
enter_local_scope()
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
leave_local_scope()
def test_var_get_int(self):
def __new_scope__():
i = create_var("var_i")
self.assertFalse(i.is_int())
i.set_int(10)
self.assertTrue(i.is_int())
self.assertEqual(10, i.get_int())
for _ in xrange(10):
scoped_function(__new_scope__)
if __name__ == '__main__':
unittest.main()
...@@ -24,3 +24,7 @@ class TestFrameworkProto(unittest.TestCase): ...@@ -24,3 +24,7 @@ class TestFrameworkProto(unittest.TestCase):
attr.type = attr_type_lib.FLOAT attr.type = attr_type_lib.FLOAT
op_proto.type = "cos" op_proto.type = "cos"
self.assertTrue(op_proto.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
if __name__ == "__main__":
unittest.main()
import paddle.v2.framework.core
import unittest
class TestScope(unittest.TestCase):
def test_create_destroy(self):
paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None)
self.assertIsNotNone(scope)
scope_with_parent = paddle_c.Scope(scope)
self.assertIsNotNone(scope_with_parent)
def test_none_variable(self):
paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None)
self.assertIsNone(scope.get_var("test"))
def test_create_var_get_var(self):
paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None)
var_a = scope.create_var("var_a")
self.assertIsNotNone(var_a)
self.assertIsNotNone(scope.get_var('var_a'))
scope2 = paddle_c.Scope(scope)
self.assertIsNotNone(scope2.get_var('var_a'))
def test_var_get_int(self):
paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None)
var = scope.create_var("test_int")
var.set_int(10)
self.assertTrue(var.is_int())
self.assertEqual(10, var.get_int())
if __name__ == '__main__':
unittest.main()
...@@ -66,6 +66,8 @@ class Optimizer(object): ...@@ -66,6 +66,8 @@ class Optimizer(object):
if use_sparse_remote_updater: if use_sparse_remote_updater:
gradient_machine.prefetch(in_args) gradient_machine.prefetch(in_args)
parameter_updater.getParametersRemote() parameter_updater.getParametersRemote()
:param pserver_spec: pserver location, eg: localhost:3000
:return: parameter_updater :return: parameter_updater
""" """
if is_local: if is_local:
......
...@@ -41,6 +41,7 @@ class SGD(object): ...@@ -41,6 +41,7 @@ class SGD(object):
:type parameters: paddle.v2.parameters.Parameters :type parameters: paddle.v2.parameters.Parameters
:param extra_layers: Some layers in the neural network graph are not :param extra_layers: Some layers in the neural network graph are not
in the path of cost layer. in the path of cost layer.
:param pserver_spec: pserver location, eg: localhost:3000
:type extra_layers: paddle.v2.config_base.Layer :type extra_layers: paddle.v2.config_base.Layer
""" """
......
...@@ -29,7 +29,9 @@ setup(name='paddle', ...@@ -29,7 +29,9 @@ setup(name='paddle',
description='Parallel Distributed Deep Learning', description='Parallel Distributed Deep Learning',
install_requires=setup_requires, install_requires=setup_requires,
packages=packages, packages=packages,
package_data={'paddle.v2.master': ['libpaddle_master.so'], }, package_data={'paddle.v2.master': ['libpaddle_master.so'],
'paddle.v2.framework': ['core.so']
},
package_dir={ package_dir={
'': '${CMAKE_CURRENT_SOURCE_DIR}', '': '${CMAKE_CURRENT_SOURCE_DIR}',
# The paddle.v2.framework.proto will be generated while compiling. # The paddle.v2.framework.proto will be generated while compiling.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册