提交 b861c019 编写于 作者: Y Yu Yang 提交者: GitHub

Merge branch 'develop' into feature/uniform_random_op

...@@ -38,12 +38,11 @@ RUN apt-get update && \ ...@@ -38,12 +38,11 @@ RUN apt-get update && \
RUN pip --no-cache-dir install 'numpy>=1.12.0' RUN pip --no-cache-dir install 'numpy>=1.12.0'
# Install Go and glide # Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \ RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \
tar -C /usr/local -xzf go.tgz && \ tar -xz -C /usr/local && \
mkdir /root/gopath && \ mkdir /root/gopath && \
mkdir /root/gopath/bin && \ mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \ mkdir /root/gopath/src
rm go.tgz
ENV GOROOT=/usr/local/go GOPATH=/root/gopath ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT. # should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
......
...@@ -8,7 +8,7 @@ ExternalProject_Add( ...@@ -8,7 +8,7 @@ ExternalProject_Add(
extern_lib_any extern_lib_any
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/PaddlePaddle/any.git" GIT_REPOSITORY "https://github.com/PaddlePaddle/any.git"
GIT_TAG "8fef1e93710a0edf8d7658999e284a1142c4c020" GIT_TAG "15595d8324be9e8a9a80d9ae442fdd12bd66df5d"
PREFIX ${ANY_SOURCE_DIR} PREFIX ${ANY_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -43,8 +43,8 @@ SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") ...@@ -43,8 +43,8 @@ SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
INCLUDE_DIRECTORIES(${MKLML_INC_DIR}) INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
SET(mklml_cmakefile ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt) FILE(WRITE ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt
FILE(WRITE ${mklml_cmakefile} "PROJECT(MKLML)\n" "PROJECT(MKLML)\n"
"cmake_minimum_required(VERSION 3.0)\n" "cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${MKLML_VER}\n" "install(DIRECTORY ${MKLML_VER}\n"
" DESTINATION ${MKLML_DST_DIR})\n") " DESTINATION ${MKLML_DST_DIR})\n")
...@@ -54,8 +54,7 @@ ExternalProject_Add( ...@@ -54,8 +54,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${MKLML_SOURCE_DIR} PREFIX ${MKLML_SOURCE_DIR}
DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR} DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate -O ${MKLML_DOWNLOAD_DIR}/${MKLML_VER}.tgz ${MKLML_URL} DOWNLOAD_COMMAND wget --no-check-certificate -qO- ${MKLML_URL} | tar xz -C ${MKLML_DOWNLOAD_DIR}
&& tar -xzf ${MKLML_DOWNLOAD_DIR}/${MKLML_VER}.tgz
DOWNLOAD_NO_PROGRESS 1 DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT} CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT}
......
...@@ -3,6 +3,43 @@ PaddlePaddle的Docker容器使用方式 ...@@ -3,6 +3,43 @@ PaddlePaddle的Docker容器使用方式
PaddlePaddle目前唯一官方支持的运行的方式是Docker容器。因为Docker能在所有主要操作系统(包括Linux,Mac OS X和Windows)上运行。 请注意,您需要更改 `Dockers设置 <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 才能充分利用Mac OS X和Windows上的硬件资源。 PaddlePaddle目前唯一官方支持的运行的方式是Docker容器。因为Docker能在所有主要操作系统(包括Linux,Mac OS X和Windows)上运行。 请注意,您需要更改 `Dockers设置 <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 才能充分利用Mac OS X和Windows上的硬件资源。
Docker使用入门
------------------------------
几个基础的概念帮助理解和使用Docker:
- *镜像*:一个Docker镜像是一个打包好的软件。它包含了这个软件本身和它所依赖的运行环境。PaddlePaddle的Docker镜像就包含了PaddlePaddle的Python库以及其依赖的多个Python库。这样我们可以直接在Docker中运行需要的程序而不需要安装后在执行。可以执行:
.. code-block:: bash
docker images
来列出当前系统中的所有镜像,同样可以执行:
.. code-block:: bash
docker pull paddlepaddle/paddle:0.10.0
来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。
- *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。
实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。
可以执行:
.. code-block:: bash
docker run paddlepaddle/paddle:0.10.0
来使用一个镜像启动一个容器。
- 默认情况下,Docker容器会运行在独立的文件系统空间之上,我们无法在Docker容器中
访问到主机上的文件。可以通过*挂载Volume*的方式,将主机上的文件或目录挂载到
Docker容器中。下面的命令把当前目录挂载到了容器中的 /data 目录下,容器使用
debian镜像,并且启动后执行 :code:`ls /data`。
.. code-block:: bash
docker run --rm -v $(pwd):/data debian ls /data
PaddlePaddle发布的Docker镜像使用说明 PaddlePaddle发布的Docker镜像使用说明
------------------------------ ------------------------------
...@@ -12,11 +49,11 @@ PaddlePaddle需要的所有编译工具。把编译出来的PaddlePaddle也打 ...@@ -12,11 +49,11 @@ PaddlePaddle需要的所有编译工具。把编译出来的PaddlePaddle也打
像,称为生产镜像,里面涵盖了PaddlePaddle运行所需的所有环境。每次 像,称为生产镜像,里面涵盖了PaddlePaddle运行所需的所有环境。每次
PaddlePaddle发布新版本的时候都会发布对应版本的生产镜像以及开发镜像。运 PaddlePaddle发布新版本的时候都会发布对应版本的生产镜像以及开发镜像。运
行镜像包括纯CPU版本和GPU版本以及其对应的非AVX版本。我们会在 行镜像包括纯CPU版本和GPU版本以及其对应的非AVX版本。我们会在
`dockerhub.com <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_ 提供最新 `dockerhub.com <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_
的Docker镜像,可以在"tags"标签下找到最新的Paddle镜像版本。为了方便在国 和国内镜像`docker.paddlepaddle.org` 提供最新
内的开发者下载Docker镜像,我们提供了国内的镜像服务器供大家使用。如果您 的Docker镜像,可以在"tags"标签下找到最新的Paddle镜像版本。
在国内,请把文档里命令中的paddlepaddle/paddle替换成
docker.paddlepaddle.org/paddle。 **注意:为了方便在国内的开发者下载Docker镜像,我们提供了国内的镜像服务器供大家使用。如果您在国内,请把文档里命令中的paddlepaddle/paddle替换成docker.paddlepaddle.org/paddle。**
1. 开发镜像::code:`paddlepaddle/paddle:0.10.0-dev` 1. 开发镜像::code:`paddlepaddle/paddle:0.10.0-dev`
...@@ -68,6 +105,8 @@ docker.paddlepaddle.org/paddle。 ...@@ -68,6 +105,8 @@ docker.paddlepaddle.org/paddle。
如果输出是No,就需要选择使用no-AVX的镜像 如果输出是No,就需要选择使用no-AVX的镜像
**注:在0.10.0之后的版本,PaddlePaddle都可以自动判断硬件是否支持AVX,所以无需判断AVX即可使用**
以上方法在GPU镜像里也能用,只是请不要忘记提前在物理机上安装GPU最新驱动。 以上方法在GPU镜像里也能用,只是请不要忘记提前在物理机上安装GPU最新驱动。
为了保证GPU驱动能够在镜像里面正常运行,我们推荐使用[nvidia-docker](https://github.com/NVIDIA/nvidia-docker)来运行镜像。 为了保证GPU驱动能够在镜像里面正常运行,我们推荐使用[nvidia-docker](https://github.com/NVIDIA/nvidia-docker)来运行镜像。
......
...@@ -63,12 +63,35 @@ CPU-only version and a CUDA GPU version and their no-AVX versions. ...@@ -63,12 +63,35 @@ CPU-only version and a CUDA GPU version and their no-AVX versions.
We put the docker images on `dockerhub.com We put the docker images on `dockerhub.com
<https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_. You can find the <https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_. You can find the
latest versions under "tags" tab at dockerhub.com. If you are in latest versions under "tags" tab at dockerhub.com.
China, you can use our Docker image registry mirror to speed up the
download process. To use it, please replace all paddlepaddle/paddle in
the commands to docker.paddlepaddle.org/paddle.
1. Production images, this image might have multiple variants: ** NOTE: If you are in China, you can use our Docker image registry mirror to speed up the download process. To use it, please replace all paddlepaddle/paddle in the commands to docker.paddlepaddle.org/paddle.**
1. development image :code:`paddlepaddle/paddle:<version>-dev`
This image has packed related develop tools and runtime
environment. Users and developers can use this image instead of
their own local computer to accomplish development, build,
releasing, document writing etc. While different version of paddle
may depends on different version of libraries and tools, if you
want to setup a local environment, you must pay attention to the
versions. The development image contains:
- gcc/clang
- nvcc
- Python
- sphinx
- woboq
- sshd
Many developers use servers with GPUs, they can use ssh to login to
the server and run :code:`docker exec` to enter the docker
container and start their work. Also they can start a development
docker image with SSHD service, so they can login to the container
and start work.
2. Production images, this image might have multiple variants:
- GPU/AVX::code:`paddlepaddle/paddle:<version>-gpu` - GPU/AVX::code:`paddlepaddle/paddle:<version>-gpu`
- GPU/no-AVX::code:`paddlepaddle/paddle:<version>-gpu-noavx` - GPU/no-AVX::code:`paddlepaddle/paddle:<version>-gpu-noavx`
...@@ -84,7 +107,7 @@ the commands to docker.paddlepaddle.org/paddle. ...@@ -84,7 +107,7 @@ the commands to docker.paddlepaddle.org/paddle.
if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi
**NOTE:versions after 0.10.0 will automatically detect system AVX support, so manual detect is not needed in this case.**
To run the CPU-only image as an interactive container: To run the CPU-only image as an interactive container:
.. code-block:: bash .. code-block:: bash
...@@ -103,29 +126,6 @@ the commands to docker.paddlepaddle.org/paddle. ...@@ -103,29 +126,6 @@ the commands to docker.paddlepaddle.org/paddle.
nvidia-docker run -it --rm paddlepaddle/paddle:0.10.0-gpu /bin/bash nvidia-docker run -it --rm paddlepaddle/paddle:0.10.0-gpu /bin/bash
2. development image :code:`paddlepaddle/paddle:<version>-dev`
This image has packed related develop tools and runtime
environment. Users and developers can use this image instead of
their own local computer to accomplish development, build,
releasing, document writing etc. While different version of paddle
may depends on different version of libraries and tools, if you
want to setup a local environment, you must pay attention to the
versions. The development image contains:
- gcc/clang
- nvcc
- Python
- sphinx
- woboq
- sshd
Many developers use servers with GPUs, they can use ssh to login to
the server and run :code:`docker exec` to enter the docker
container and start their work. Also they can start a development
docker image with SSHD service, so they can login to the container
and start work.
Train Model Using Python API Train Model Using Python API
---------------------------- ----------------------------
......
...@@ -32,7 +32,7 @@ import ( ...@@ -32,7 +32,7 @@ import (
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
...@@ -60,12 +60,12 @@ func main() { ...@@ -60,12 +60,12 @@ func main() {
idx, err = e.Register(*port) idx, err = e.Register(*port)
candy.Must(err) candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil { if err != nil {
if err == pserver.ErrCheckpointNotFound { if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.") log.Infof("Could not find the pserver checkpoint.")
} else { } else {
log.Errorf("Fetch checkpoint failed, %s", err) panic(err)
} }
} }
} }
......
hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-07-29T07:34:48.722757905+08:00 updated: 2017-08-03T21:46:51.744995189Z
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...@@ -145,6 +145,8 @@ imports: ...@@ -145,6 +145,8 @@ imports:
version: a1dba9ce8baed984a2495b658c82687f8157b98f version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages: subpackages:
- xfs - xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: a3f95b5c423586578a4e099b11a46c2479628cac version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy - name: github.com/topicai/candy
......
...@@ -14,11 +14,13 @@ import: ...@@ -14,11 +14,13 @@ import:
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy - package: github.com/topicai/candy
- package: golang.org/x/crypto - package: golang.org/x/crypto
vcs: git
repo: https://github.com/golang/crypto.git repo: https://github.com/golang/crypto.git
- package: golang.org/x/sys
vcs: git vcs: git
- package: golang.org/x/sys
repo: https://github.com/golang/sys.git repo: https://github.com/golang/sys.git
- package: golang.org/x/text
vcs: git vcs: git
- package: golang.org/x/text
repo: https://github.com/golang/text.git repo: https://github.com/golang/text.git
vcs: git
- package: github.com/satori/go.uuid
version: v1.1.0
...@@ -77,11 +77,12 @@ type taskEntry struct { ...@@ -77,11 +77,12 @@ type taskEntry struct {
NumFailure int NumFailure int
} }
type taskQueues struct { type masterState struct {
Todo []taskEntry Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry Done []taskEntry
Failed []taskEntry Failed []taskEntry
CurPass int
} }
// Service is the master server service. // Service is the master server service.
...@@ -95,10 +96,10 @@ type Service struct { ...@@ -95,10 +96,10 @@ type Service struct {
initDone bool initDone bool
mu sync.Mutex mu sync.Mutex
taskQueues taskQueues // State to be persisted to snapshot.
currPass int state masterState
jobTasks []taskEntry // The trainer that is currently saving model. This state is
// transient, does not need to be persisted to snapshot.
savingTrainer string savingTrainer string
} }
...@@ -141,8 +142,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur ...@@ -141,8 +142,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.failureMax = failureMax s.failureMax = failureMax
s.taskQueues = taskQueues{} s.state = masterState{}
s.taskQueues.Pending = make(map[int]taskEntry) s.state.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
s.store = store s.store = store
recovered, err := s.recover() recovered, err := s.recover()
...@@ -180,7 +181,7 @@ func (s *Service) recover() (bool, error) { ...@@ -180,7 +181,7 @@ func (s *Service) recover() (bool, error) {
} }
dec := gob.NewDecoder(gr) dec := gob.NewDecoder(gr)
var tqs taskQueues var tqs masterState
err = dec.Decode(&tqs) err = dec.Decode(&tqs)
if err != nil { if err != nil {
return false, err return false, err
...@@ -193,7 +194,12 @@ func (s *Service) recover() (bool, error) { ...@@ -193,7 +194,12 @@ func (s *Service) recover() (bool, error) {
log.Errorln(err) log.Errorln(err)
} }
s.taskQueues = tqs s.state = tqs
log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.")
for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
}
return true, nil return true, nil
} }
...@@ -208,7 +214,7 @@ func (s *Service) snapshot() error { ...@@ -208,7 +214,7 @@ func (s *Service) snapshot() error {
var buf bytes.Buffer var buf bytes.Buffer
gw := gzip.NewWriter(&buf) gw := gzip.NewWriter(&buf)
enc := gob.NewEncoder(gw) enc := gob.NewEncoder(gw)
err := enc.Encode(s.taskQueues) err := enc.Encode(s.state)
if err != nil { if err != nil {
return err return err
} }
...@@ -290,8 +296,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error { ...@@ -290,8 +296,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
return err return err
} }
s.jobTasks = partition(chunks, s.chunksPerTask) s.state.Todo = partition(chunks, s.chunksPerTask)
s.taskQueues.Todo = s.jobTasks
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
...@@ -319,17 +324,17 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -319,17 +324,17 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
} }
}() }()
delete(s.taskQueues.Pending, t.Task.Meta.ID) delete(s.state.Pending, t.Task.Meta.ID)
t.NumFailure++ t.NumFailure++
if t.NumFailure > s.failureMax { if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t) s.state.Failed = append(s.state.Failed, t)
return return
} }
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure) log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.state.Todo = append(s.state.Todo, t)
return return
} }
...@@ -338,7 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -338,7 +343,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
return return
} }
...@@ -350,10 +355,11 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -350,10 +355,11 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
// must be called with lock held. // must be called with lock held.
func (s *Service) logFields() log.Fields { func (s *Service) logFields() log.Fields {
return log.Fields{ return log.Fields{
"todoLen": len(s.taskQueues.Todo), "todoLen": len(s.state.Todo),
"pendingLen": len(s.taskQueues.Pending), "pendingLen": len(s.state.Pending),
"doneLen": len(s.taskQueues.Done), "doneLen": len(s.state.Done),
"failedLen": len(s.taskQueues.Failed), "failedLen": len(s.state.Failed),
"curPass": s.state.CurPass,
} }
} }
...@@ -366,17 +372,17 @@ func (s *Service) GetTask(passID int, task *Task) error { ...@@ -366,17 +372,17 @@ func (s *Service) GetTask(passID int, task *Task) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if passID < s.currPass { if passID < s.state.CurPass {
return ErrPassBefore return ErrPassBefore
} }
if passID > s.currPass { if passID > s.state.CurPass {
// Client may get run to pass after master when one client faster than the // Client may get run to pass after master when one client faster than the
// other // other
return ErrPassAfter return ErrPassAfter
} }
if len(s.taskQueues.Todo) == 0 { if len(s.state.Todo) == 0 {
if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 { if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass") log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
return ErrAllTaskFailed return ErrAllTaskFailed
} }
...@@ -384,10 +390,10 @@ func (s *Service) GetTask(passID int, task *Task) error { ...@@ -384,10 +390,10 @@ func (s *Service) GetTask(passID int, task *Task) error {
return ErrNoMoreAvailable return ErrNoMoreAvailable
} }
t := s.taskQueues.Todo[0] t := s.state.Todo[0]
t.Task.Meta.Epoch++ t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:] s.state.Todo = s.state.Todo[1:]
s.taskQueues.Pending[t.Task.Meta.ID] = t s.state.Pending[t.Task.Meta.ID] = t
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
return err return err
...@@ -409,7 +415,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -409,7 +415,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return nil return nil
...@@ -417,18 +423,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -417,18 +423,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
// task finished, reset timeout // task finished, reset timeout
t.NumFailure = 0 t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t) s.state.Done = append(s.state.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.state.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 { if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished // increase master side pass count if all tasks finished
s.currPass++ s.state.CurPass++
s.taskQueues.Todo = s.jobTasks s.state.Todo = append(s.state.Done, s.state.Failed...)
s.taskQueues.Done = []taskEntry{} s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks // TODO(typhoonzero): deal with failed tasks
s.taskQueues.Failed = []taskEntry{} s.state.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass) log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass)
} }
err := s.snapshot() err := s.snapshot()
...@@ -447,7 +453,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { ...@@ -447,7 +453,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID] t, ok := s.state.Pending[meta.ID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta) log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil return nil
......
...@@ -59,7 +59,7 @@ func initClient() [numPserver]int { ...@@ -59,7 +59,7 @@ func initClient() [numPserver]int {
go func(l net.Listener) { go func(l net.Listener) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -103,7 +103,7 @@ func (p *EtcdClient) List() []Server { ...@@ -103,7 +103,7 @@ func (p *EtcdClient) List() []Server {
time.Sleep(p.timeout) time.Sleep(p.timeout)
continue continue
} }
log.Infof("got value (%s) for key: %s", psAddr, psKey) log.Debugf("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i servers[i].Index = i
servers[i].Addr = psAddr servers[i].Addr = psAddr
} }
......
...@@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ...@@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
return []byte{}, nil return []byte{}, nil
...@@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ...@@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
} }
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) var err error
if withLease {
_, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
} else {
_, err = e.client.Put(ctx, key, string(value))
}
cancel() cancel()
return err return err
} }
......
...@@ -32,6 +32,7 @@ type optimizer struct { ...@@ -32,6 +32,7 @@ type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType elementType ElementType
contentLen int contentLen int
config []byte
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
...@@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer ...@@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0]) cstate = unsafe.Pointer(&s[0])
} }
o.config = c
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s))) C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
return o return o
......
...@@ -25,11 +25,13 @@ import ( ...@@ -25,11 +25,13 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path"
"strconv" "strconv"
"sync" "sync"
"time" "time"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -44,7 +46,7 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found") ...@@ -44,7 +46,7 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found")
const ( const (
AlreadyInitialized = "pserver already initialized" AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized" Uninitialized = "pserver not fully initialized"
CheckpointMD5Failed = "checkpoint file MD5 validation failed" WrongChecksum = "checkpoint file checksum validation failed"
) )
// Supported element types. // Supported element types.
...@@ -73,11 +75,12 @@ type ParameterWithConfig struct { ...@@ -73,11 +75,12 @@ type ParameterWithConfig struct {
// checkpointMeta saves checkpoint metadata // checkpointMeta saves checkpoint metadata
type checkpointMeta struct { type checkpointMeta struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Path string `json:"path"`
MD5 string `json:"md5"` MD5 string `json:"md5"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
// Checkpoint is the pserver shard persist in file // Checkpoint is the pserver shard persist in file.
type Checkpoint []parameterCheckpoint type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
...@@ -90,50 +93,58 @@ type Service struct { ...@@ -90,50 +93,58 @@ type Service struct {
checkpointInterval time.Duration checkpointInterval time.Duration
checkpointPath string checkpointPath string
client *EtcdClient client *EtcdClient
mu sync.Mutex mu sync.Mutex
optMap map[string]*optimizer optMap map[string]*optimizer
} }
// parameterCheckpoint saves parameter checkpoint // parameterCheckpoint saves parameter checkpoint.
type parameterCheckpoint struct { type parameterCheckpoint struct {
ParameterWithConfig ParameterWithConfig
State []byte State []byte
} }
// NewCheckpointFromFile loads parameters and state from checkpoint file func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) { v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil { if err != nil {
return nil, err return
} }
if len(v) == 0 { if len(v) == 0 {
return nil, ErrCheckpointNotFound err = ErrCheckpointNotFound
return
} }
var cpMeta checkpointMeta if err = json.Unmarshal(v, &meta); err != nil {
if err = json.Unmarshal(v, &cpMeta); err != nil { return
return nil, err
} }
fn := filepath.Join(cpPath, cpMeta.UUID) return
if _, err = os.Stat(fn); os.IsNotExist(err) { }
// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
cpMeta, err := loadMeta(e, idx)
if err != nil {
return nil, err return nil, err
} }
content, err := ioutil.ReadFile(fn)
content, err := ioutil.ReadFile(cpMeta.Path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h := md5.New() h := md5.New()
md5 := hex.EncodeToString(h.Sum(content)) md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 { if md5 != cpMeta.MD5 {
return nil, errors.New(CheckpointMD5Failed) return nil, errors.New(WrongChecksum)
} }
dec := gob.NewDecoder(bytes.NewReader(content)) dec := gob.NewDecoder(bytes.NewReader(content))
cp := Checkpoint{} var cp Checkpoint
if err = dec.Decode(cp); err != nil { if err = dec.Decode(&cp); err != nil {
return nil, err return nil, err
} }
return cp, nil return cp, nil
...@@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error { ...@@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
} }
close(s.initialized) close(s.initialized)
go func() {
t := time.Tick(s.checkpointInterval)
for range t {
err := s.checkpoint()
if err != nil {
log.Errorln(err)
}
}
}()
return nil return nil
} }
...@@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil return nil
} }
// pserver save checkpoint func traceTime(start time.Time, name string) {
func (s *Service) doCheckpoint() (err error) { elapsed := time.Since(start)
<-s.initialized log.Infof("%s took %v", name, elapsed)
s.mu.Lock() }
defer s.mu.Unlock()
// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
log.Infoln("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock()
cp := make([]parameterCheckpoint, len(s.optMap)) cp := make([]parameterCheckpoint, len(s.optMap))
index := 0 index := 0
// TODO(helin): write checkpoint incrementally to reduce memory
// footprint during checkpoint.
for name, opt := range s.optMap { for name, opt := range s.optMap {
var pc parameterCheckpoint var pc parameterCheckpoint
pc.Param.Name = name pc.Param.Name = name
pc.Param.ElementType = opt.elementType pc.Param.ElementType = opt.elementType
pc.Param.Content = opt.GetWeights() pc.Param.Content = opt.GetWeights()
pc.Config = opt.config
pc.State = opt.GetStates() pc.State = opt.GetStates()
cp[index] = pc cp[index] = pc
index++ index++
} }
s.mu.Unlock()
var buf bytes.Buffer var buf bytes.Buffer
encoder := gob.NewEncoder(&buf) encoder := gob.NewEncoder(&buf)
err = encoder.Encode(cp) err = encoder.Encode(cp)
...@@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) { ...@@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) {
return return
} }
cpMeta := checkpointMeta{} id := uuid.NewV4().String()
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) p := path.Join(s.checkpointPath, id)
cpMeta.Timestamp = time.Now().UnixNano() f, err := os.Create(p)
h := md5.New()
cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, err := json.Marshal(cpMeta)
if err != nil {
return
}
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil {
return
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
}
f, err := os.Create(cpMeta.UUID)
if err != nil { if err != nil {
return return
} }
...@@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) { ...@@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) {
return return
} }
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Infoln("Do not have existing checkpoint.")
err = nil
}
if err != nil {
return
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
MD5: md5,
Path: p,
}
json, err := json.Marshal(cpMeta)
if err != nil {
return
}
err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
if err != nil {
return
}
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Errorln(rmErr)
}
}
return return
} }
...@@ -30,7 +30,7 @@ const ( ...@@ -30,7 +30,7 @@ const (
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) { ...@@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) {
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) { ...@@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) {
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.Fatal(err) t.Fatal(err)
...@@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) { ...@@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) {
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
......
...@@ -39,6 +39,7 @@ set(CUDA_CU_SOURCES ...@@ -39,6 +39,7 @@ set(CUDA_CU_SOURCES
src/hl_cuda_lstm.cu src/hl_cuda_lstm.cu
src/hl_top_k.cu src/hl_top_k.cu
src/hl_batch_transpose.cu src/hl_batch_transpose.cu
src/hl_batch_norm.cu
src/hl_cuda_sequence.cu src/hl_cuda_sequence.cu
src/hl_table_apply.cu) src/hl_table_apply.cu)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_BATCH_NORM_H_
#define HL_BATCH_NORM_H_
#include "hl_base.h"
/**
* @brief batch norm inferece.
*
* @param[in] input input data.
* @param[out] output output data.
* @param[in] scale batch normalization scale parameter (in original
* paper scale is referred to as gamma).
* @param[in] bias batch normalization bias parameter (in original
* paper scale is referred to as beta).
* @param[in] estimatedMean
* @param[in] estimatedVar The moving mean and variance
* accumulated during the training phase are passed
* as inputs here.
* @param[in] epsilon Epsilon value used in the batch
* normalization formula.
*/
extern void hl_batch_norm_cuda_inference(const real* input,
real* output,
const real* scale,
const real* bias,
const real* estimatedMean,
const real* estimatedVar,
const double epsilon,
size_t batchSize,
size_t channel,
size_t height,
size_t width);
#endif // HL_BATCH_NORM_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_batch_norm.h"
__global__ void batchNormInference(real* output,
const real* input,
const real* scale,
const real* bias,
const real* estimatedMean,
const real* estimatedVar,
const double epsilon,
size_t batchSize,
size_t channel,
size_t height,
size_t width) {
const int tid = threadIdx.x;
const int num = channel * height * width;
const int batch = blockIdx.x;
for (int i = tid; i < num; i += blockDim.x) {
const int c = i / (height * width);
const int id = batch * num + i;
real val = input[id] - estimatedMean[c];
val /= sqrt(estimatedVar[c] + epsilon);
val *= scale[c];
val += bias[c];
output[id] = val;
}
}
void hl_batch_norm_cuda_inference(const real* input,
real* output,
const real* scale,
const real* bias,
const real* estimatedMean,
const real* estimatedVar,
const double epsilon,
size_t batchSize,
size_t channel,
size_t height,
size_t width) {
batchNormInference<<<batchSize, 256, 0, STREAM_DEFAULT>>>(output,
input,
scale,
bias,
estimatedMean,
estimatedVar,
epsilon,
batchSize,
channel,
height,
width);
CHECK_SYNC("hl_batch_norm_cuda_inference failed!");
}
...@@ -1023,14 +1023,6 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, ...@@ -1023,14 +1023,6 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
real beta = 1.0f; real beta = 1.0f;
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size;
if (batch_size > 1024 && g_cudnn_lib_version < 6000) {
LOG(INFO) << " To process current batch data with size " << batch_size
<< " (>1024), cudnnBatchNorm requires cuDNN version >= 6000."
<< " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED,"
<< " just recompile PaddlePaddle with cuDNN >= 6000, replacing"
<< " current version " << g_cudnn_lib_version;
}
CHECK_CUDNN( CHECK_CUDNN(
dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle,
mode, mode,
......
...@@ -44,4 +44,5 @@ cc_library(paddle_pybind SHARED ...@@ -44,4 +44,5 @@ cc_library(paddle_pybind SHARED
mean_op mean_op
cross_entropy_op cross_entropy_op
recurrent_op recurrent_op
uniform_random_op) uniform_random_op
fill_zeros_like_op)
...@@ -40,6 +40,7 @@ USE_OP(mean); ...@@ -40,6 +40,7 @@ USE_OP(mean);
USE_OP(sigmoid); USE_OP(sigmoid);
USE_OP(softmax); USE_OP(softmax);
USE_OP(rowwise_add); USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op); USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(uniform_random); USE_OP(uniform_random);
namespace paddle { namespace paddle {
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "CudnnBatchNormLayer.h" #include "CudnnBatchNormLayer.h"
#include "Layer.h" #include "Layer.h"
#include "paddle/cuda/include/hl_batch_norm.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
namespace paddle { namespace paddle {
...@@ -79,6 +80,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -79,6 +80,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
savedInvVar); savedInvVar);
} else { } else {
// used movingMean and movingVar in testing // used movingMean and movingVar in testing
if (batchSize <= 1024) {
hl_batch_norm_forward_inference(ioDesc_, hl_batch_norm_forward_inference(ioDesc_,
input, input,
ioDesc_, ioDesc_,
...@@ -89,6 +91,22 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -89,6 +91,22 @@ void CudnnBatchNormLayer::forward(PassType passType) {
movingMean, movingMean,
movingVar, movingVar,
EPS); EPS);
} else {
// There is a limitation in cudnn library.
// When the batch size is larger than 1024 in cuDNN v5.1,
// the cudnnBatchNormalizationForwardInference will fail.
hl_batch_norm_cuda_inference(input,
output,
gamma,
beta,
movingMean,
movingVar,
EPS,
batchSize,
channels_,
imageH_,
imageW_);
}
} }
/* activation */ { /* activation */ {
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/GlobalConstants.h"
#include "LayerGradUtil.h" #include "LayerGradUtil.h"
#include "paddle/cuda/include/hl_batch_norm.h"
#include "paddle/math/tests/TensorCheck.h"
#include "paddle/testing/TestUtil.h" #include "paddle/testing/TestUtil.h"
using namespace paddle; // NOLINT using namespace paddle; // NOLINT
...@@ -117,6 +119,74 @@ TEST(Layer, batchNorm) { ...@@ -117,6 +119,74 @@ TEST(Layer, batchNorm) {
CHECK_EQ(static_cast<int>(convLayer->getOutputValue()->getWidth()), 576); CHECK_EQ(static_cast<int>(convLayer->getOutputValue()->getWidth()), 576);
} }
#ifndef PADDLE_ONLY_CPU
void batchNormInference(int n, int c, int h, int w) {
MatrixPtr input = std::make_shared<GpuMatrix>(n, c * h * w);
MatrixPtr cudnnOut = std::make_shared<GpuMatrix>(n, c * h * w);
MatrixPtr cudaOut = std::make_shared<GpuMatrix>(n, c * h * w);
MatrixPtr cudnnCheck = std::make_shared<CpuMatrix>(n, c * h * w);
MatrixPtr cudaCheck = std::make_shared<CpuMatrix>(n, c * h * w);
input->randomizeUniform();
cudnnOut->zeroMem();
cudaOut->zeroMem();
MatrixPtr scale = std::make_shared<GpuMatrix>(1, c);
scale->randomizeUniform();
MatrixPtr bias = std::make_shared<GpuMatrix>(1, c);
bias->randomizeUniform();
MatrixPtr movingMean = std::make_shared<GpuMatrix>(1, c);
movingMean->randomizeUniform();
MatrixPtr movingVar = std::make_shared<GpuMatrix>(1, c);
movingVar->randomizeUniform();
movingVar->clip(0.01, 50);
hl_tensor_descriptor ioDesc;
hl_tensor_descriptor bnDesc;
hl_create_tensor_descriptor(&ioDesc);
hl_create_tensor_descriptor(&bnDesc);
hl_tensor_reshape(ioDesc, n, c, h, w);
hl_tensor_reshape(bnDesc, 1, c, 1, 1);
double EPS = 1E-5;
hl_batch_norm_forward_inference(ioDesc,
input->getData(),
ioDesc,
cudnnOut->getData(),
bnDesc,
scale->getData(),
bias->getData(),
movingMean->getData(),
movingVar->getData(),
EPS);
hl_batch_norm_cuda_inference(input->getData(),
cudaOut->getData(),
scale->getData(),
bias->getData(),
movingMean->getData(),
movingVar->getData(),
EPS,
n,
c,
h,
w);
cudnnCheck->copyFrom(*cudnnOut);
cudaCheck->copyFrom(*cudaOut);
autotest::TensorCheckErr(*cudnnCheck, *cudaCheck);
hl_destroy_tensor_descriptor(ioDesc);
hl_destroy_tensor_descriptor(bnDesc);
}
TEST(BatchNorm, Inference) {
batchNormInference(33, 267, 1, 1);
batchNormInference(19, 105, 4, 4);
}
#endif
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
initMain(argc, argv); initMain(argc, argv);
......
...@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/fill_zeros_like_op.h" #include "paddle/operators/fill_zeros_like_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h" #include "paddle/operators/fill_zeros_like_op.h"
......
...@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,7 +24,8 @@ class FillZerosLikeKernel : public framework::OpKernel { ...@@ -26,7 +24,8 @@ class FillZerosLikeKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>(0); auto* output = context.Output<framework::Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).setZero(); auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(T(0));
} }
}; };
......
...@@ -6,4 +6,5 @@ cc_library(paddle_pybind SHARED ...@@ -6,4 +6,5 @@ cc_library(paddle_pybind SHARED
add_op add_op
mean_op mean_op
cross_entropy_op cross_entropy_op
recurrent_op) recurrent_op
fill_zeros_like_op)
...@@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py) ...@@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py)
py_test(test_add_two_op SRCS test_add_two_op.py) py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py) py_test(gradient_checker SRCS gradient_checker.py)
......
import unittest
from op_test_util import OpTestMeta
import numpy
class TestFillZerosLikeOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "fill_zeros_like"
self.inputs = {'Src': numpy.random.random((219, 232)).astype("float32")}
self.outputs = {'Dst': numpy.zeros_like(self.inputs['Src'])}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册