提交 67481ca8 编写于 作者: Y Yi Wang

Merge branch 'develop' of https://github.com/paddlepaddle/paddle into memory_cpu_allocator

...@@ -2,7 +2,6 @@ group: deprecated-2017Q2 ...@@ -2,7 +2,6 @@ group: deprecated-2017Q2
language: cpp language: cpp
cache: cache:
directories: directories:
- $HOME/third_party
- $HOME/.ccache - $HOME/.ccache
- $HOME/.cache/pip - $HOME/.cache/pip
sudo: required sudo: required
...@@ -10,15 +9,13 @@ dist: trusty ...@@ -10,15 +9,13 @@ dist: trusty
os: os:
- linux - linux
env: env:
- JOB=DOCS - JOB=build_doc
- JOB=BUILD_AND_TEST - JOB=check_style
- JOB=PRE_COMMIT
addons: addons:
apt: apt:
packages: packages:
- gcc-4.8 - gcc-4.8
- g++-4.8 - g++-4.8
- gfortran-4.8
- git - git
- build-essential - build-essential
- python - python
...@@ -35,18 +32,7 @@ addons: ...@@ -35,18 +32,7 @@ addons:
- libtool - libtool
- ccache - ccache
before_install: before_install:
- | - if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
if [ ${JOB} == "BUILD_AND_TEST" ]; then
local change_list=`git diff --name-only $TRAVIS_COMMIT_RANGE`
if [ $? -eq 0 ]; then # if git diff return no zero, then rerun unit test.
if ! echo ${change_list} | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)'
then
echo "Only markdown docs were updated, stopping build process."
exit
fi
fi
fi
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version. # protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
...@@ -55,9 +41,7 @@ before_install: ...@@ -55,9 +41,7 @@ before_install:
- | - |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script: script:
- | - paddle/scripts/travis/$JOB.sh
timeout 2580 paddle/scripts/travis/main.sh # 43min timeout
RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi;
notifications: notifications:
email: email:
on_success: change on_success: change
......
...@@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/ ...@@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y \ apt-get install -y \
git python-pip python-dev openssh-server bison \ git python-pip python-dev openssh-server bison \
wget unzip tar xz-utils bzip2 gzip coreutils \ wget unzip tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-numpy python-matplotlib gcc g++ \ python-numpy python-matplotlib gcc g++ \
automake locales clang-format-3.8 swig doxygen cmake \ automake locales clang-format-3.8 swig doxygen cmake \
......
...@@ -21,7 +21,8 @@ IF(NOT ${CBLAS_FOUND}) ...@@ -21,7 +21,8 @@ IF(NOT ${CBLAS_FOUND})
SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE)
SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}" SET(CBLAS_LIBRARIES
"${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}"
CACHE FILEPATH "openblas library." FORCE) CACHE FILEPATH "openblas library." FORCE)
SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs) SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs)
......
...@@ -14,11 +14,41 @@ ...@@ -14,11 +14,41 @@
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
# Print and set the protobuf library information,
# finish this cmake process and exit from this file.
macro(PROMPT_PROTOBUF_LIB) macro(PROMPT_PROTOBUF_LIB)
SET(protobuf_DEPS ${ARGN})
MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}") MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}")
MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}") MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}")
MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}") MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}")
INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR})
# Assuming that all the protobuf libraries are of the same type.
IF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$")
SET(protobuf_LIBTYPE STATIC)
ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}$")
SET(protobuf_LIBTYPE SHARED)
ELSE()
MESSAGE(FATAL_ERROR "Unknown library type: ${PROTOBUF_LIBRARY}")
ENDIF()
ADD_LIBRARY(protobuf ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf PROPERTY IMPORTED_LOCATION ${PROTOBUF_LIBRARY})
ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY})
ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep})
ADD_DEPENDENCIES(protoc ${dep})
ENDFOREACH()
LIST(APPEND external_project_dependencies protobuf)
RETURN() RETURN()
endmacro() endmacro()
macro(SET_PROTOBUF_VERSION) macro(SET_PROTOBUF_VERSION)
...@@ -43,22 +73,23 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "") ...@@ -43,22 +73,23 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
endif() endif()
FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_NAME}) STRING(REPLACE "extern_" "" TARGET_DIR_NAME "${TARGET_NAME}")
SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_NAME}) SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_DIR_NAME})
SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_DIR_NAME})
SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE)
SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE)
SET(${TARGET_NAME}_LITE_LIBRARY SET(${TARGET_NAME}_LITE_LIBRARY
"${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${STATIC_LIBRARY_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(${TARGET_NAME}_LIBRARY SET(${TARGET_NAME}_LIBRARY
"${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${STATIC_LIBRARY_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(${TARGET_NAME}_PROTOC_LIBRARY SET(${TARGET_NAME}_PROTOC_LIBRARY
"${PROTOBUF_INSTALL_DIR}/lib/libprotoc${STATIC_LIBRARY_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(${TARGET_NAME}_PROTOC_EXECUTABLE SET(${TARGET_NAME}_PROTOC_EXECUTABLE
"${PROTOBUF_INSTALL_DIR}/bin/protoc${EXECUTABLE_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(OPTIONAL_CACHE_ARGS "") SET(OPTIONAL_CACHE_ARGS "")
...@@ -109,6 +140,8 @@ IF(NOT CMAKE_CROSSCOMPILING) ...@@ -109,6 +140,8 @@ IF(NOT CMAKE_CROSSCOMPILING)
SET_PROTOBUF_VERSION() SET_PROTOBUF_VERSION()
IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0") IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0")
SET(PROTOBUF_FOUND OFF) SET(PROTOBUF_FOUND OFF)
ELSE()
PROMPT_PROTOBUF_LIB()
ENDIF() ENDIF()
ENDIF(PROTOBUF_FOUND) ENDIF(PROTOBUF_FOUND)
ELSE() ELSE()
...@@ -120,18 +153,22 @@ ELSE() ...@@ -120,18 +153,22 @@ ELSE()
ENDIF() ENDIF()
IF(NOT PROTOBUF_FOUND) IF(NOT PROTOBUF_FOUND)
build_protobuf(protobuf FALSE) build_protobuf(extern_protobuf FALSE)
LIST(APPEND external_project_dependencies protobuf)
SET(PROTOBUF_INCLUDE_DIR ${protobuf_INCLUDE_DIR} SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR}
CACHE PATH "protobuf include directory." FORCE) CACHE PATH "protobuf include directory." FORCE)
IF(NOT CMAKE_CROSSCOMPILING) SET(PROTOBUF_LITE_LIBRARY ${extern_protobuf_LITE_LIBRARY}
SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_PROTOC_EXECUTABLE} CACHE FILEPATH "protobuf lite library." FORCE)
SET(PROTOBUF_LIBRARY ${extern_protobuf_LIBRARY}
CACHE FILEPATH "protobuf library." FORCE)
SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY}
CACHE FILEPATH "protoc library." FORCE)
IF(CMAKE_CROSSCOMPILING)
PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf)
ELSE()
SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE) CACHE FILEPATH "protobuf executable." FORCE)
PROMPT_PROTOBUF_LIB(extern_protobuf)
ENDIF() ENDIF()
SET(PROTOBUF_LITE_LIBRARY ${protobuf_LITE_LIBRARY} CACHE FILEPATH "protobuf lite library." FORCE)
SET(PROTOBUF_LIBRARY ${protobuf_LIBRARY} CACHE FILEPATH "protobuf library." FORCE)
SET(PROTOBUF_PROTOC_LIBRARY ${protobuf_PROTOC_LIBRARY} CACHE FILEPATH "protoc library." FORCE)
ENDIF(NOT PROTOBUF_FOUND) ENDIF(NOT PROTOBUF_FOUND)
PROMPT_PROTOBUF_LIB()
\ No newline at end of file
...@@ -84,24 +84,6 @@ IF(DEFINED CMAKE_SYSTEM_NAME) ...@@ -84,24 +84,6 @@ IF(DEFINED CMAKE_SYSTEM_NAME)
ENDIF() ENDIF()
ENDIF() ENDIF()
# prefix and suffix on different os
IF(WIN32)
SET(LIBRARY_PREFIX "")
SET(SHARED_LIBRARY_SUFFIX ".dll")
SET(STATIC_LIBRARY_SUFFIX ".lib")
SET(EXECUTABLE_SUFFIX ".exe")
ELSE(WIN32)
SET(LIBRARY_PREFIX "lib")
IF(APPLE)
SET(SHARED_LIBRARY_SUFFIX ".dylib")
ELSE(APPLE)
SET(SHARED_LIBRARY_SUFFIX ".so")
ENDIF(APPLE)
SET(STATIC_LIBRARY_SUFFIX ".a")
SET(EXECUTABLE_SUFFIX "")
ENDIF(WIN32)
# external dependencies log output # external dependencies log output
SET(EXTERNAL_PROJECT_LOG_ARGS SET(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 0 # Wrap download in script to log output LOG_DOWNLOAD 0 # Wrap download in script to log output
......
...@@ -99,3 +99,12 @@ value_printer ...@@ -99,3 +99,12 @@ value_printer
.. automodule:: paddle.v2.evaluator .. automodule:: paddle.v2.evaluator
:members: value_printer :members: value_printer
:noindex: :noindex:
Detection
=====
detection_map
-------------
.. automodule:: paddle.v2.evaluator
:members: detection_map
:noindex:
package main package main
import ( import (
"fmt"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
) )
func main() { func main() {
port := flag.Int("port", 8080, "port of the master server.") port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
flag.Parse() flag.Parse()
if *faultTolerance { if *endpoints == "" {
panic("fault tolernance not implemented.") log.Warningln("-endpoints not set, fault tolerance not be enabled.")
}
var store master.Store
if *endpoints != "" {
eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP()
if err != nil {
log.Fatal(err)
}
addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil {
log.Fatal(err)
}
} else {
store = &master.InMemStore{}
}
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil {
log.Fatal(err)
} }
s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) err = rpc.Register(s)
err := rpc.Register(s)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
} }
...@@ -5,18 +5,35 @@ import ( ...@@ -5,18 +5,35 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
) )
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
s := pserver.NewService() level, err := log.ParseLevel(*logLevel)
err := rpc.Register(s) if err != nil {
panic(err)
}
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
if err != nil {
panic(err)
}
err = rpc.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -27,7 +44,9 @@ func main() { ...@@ -27,7 +44,9 @@ func main() {
panic(err) panic(err)
} }
log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -47,9 +47,13 @@ func TestGetFinishTask(t *testing.T) { ...@@ -47,9 +47,13 @@ func TestGetFinishTask(t *testing.T) {
} }
go func(l net.Listener) { go func(l net.Listener) {
s := NewService(chunkPerTask, time.Second, 1) s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -33,9 +33,13 @@ func TestNextRecord(t *testing.T) { ...@@ -33,9 +33,13 @@ func TestNextRecord(t *testing.T) {
} }
go func(l net.Listener) { go func(l net.Listener) {
s := master.NewService(10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
package master
import (
"context"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
const (
// DefaultLockPath is the default etcd master lock path.
DefaultLockPath = "/master/lock"
// DefaultStatePath is the default etcd key for master state.
DefaultStatePath = "/master/state"
// DefaultAddrPath is the default etcd key for master address.
DefaultAddrPath = "/master/addr"
)
// EtcdClient is the etcd client that master uses for fault tolerance
// and service registry.
type EtcdClient struct {
lockPath string
statePath string
client *clientv3.Client
lock *concurrency.Mutex
}
// NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints)
// TODO(helin): gracefully shutdown etcd store. Becuase etcd
// store holds a etcd lock, even though the lock will expire
// when the lease timeout, we need to implement graceful
// shutdown to release the lock.
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: dialTimeout,
})
if err != nil {
return nil, err
}
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
if err != nil {
return nil, err
}
lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have
// one master running, but split-brain problem may cuase
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath)
err = lock.Lock(context.TODO())
if err != nil {
return nil, err
}
log.Debugf("Successfully acquired lock at %s.", lockPath)
put := clientv3.OpPut(addrPath, string(addr))
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
if err != nil {
return nil, err
}
if !resp.Succeeded {
log.Fatal("No longer owns the master lock. Exiting.")
}
e := &EtcdClient{
lockPath: lockPath,
statePath: statePath,
client: cli,
lock: lock,
}
return e, nil
}
// Save saves the state into the etcd.
func (e *EtcdClient) Save(state []byte) error {
ctx := context.TODO()
put := clientv3.OpPut(e.statePath, string(state))
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
if err != nil {
return err
}
if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx)
cancel()
if err != nil {
// We lost the master lock and can not acquire
// it back, it means some other master is
// already started. We don't want cluster
// managment system to kill the master server
// who is holding the lock and running
// correctly. So the most feasible solution is
// to kill current master server. The current
// state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry.
log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err)
}
log.Infof("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state)
}
return nil
}
// Load loads the state from etcd.
func (e *EtcdClient) Load() ([]byte, error) {
ctx := context.TODO()
get := clientv3.OpGet(e.statePath)
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
if err != nil {
return nil, err
}
if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background())
if err != nil {
return nil, err
}
return e.Load()
}
kvs := resp.Responses[0].GetResponseRange().Kvs
if len(kvs) == 0 {
// No state exists
return nil, nil
}
state := kvs[0].Value
return state, nil
}
package master
import "sync"
// InMemStore is an in memory implementation of Store interface.
//
// It does not tolerate the fault that casues the program to crash.
type InMemStore struct {
mu sync.Mutex
buf []byte
}
// Save saves the state into the in-memory store.
func (m *InMemStore) Save(state []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
m.buf = state
return nil
}
// Load loads the state from the in-memory store.
func (m *InMemStore) Load() ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.buf, nil
}
package master package master
import ( import (
"bytes"
"compress/gzip"
"encoding/gob"
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
...@@ -12,24 +15,54 @@ import ( ...@@ -12,24 +15,54 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
const (
dialTimeout = 5 * time.Second
)
// Store is the interface for save and load the master state.
type Store interface {
Save([]byte) error
Load() ([]byte, error)
}
// Chunk is a chunk of data consisted of several data instances.
type Chunk struct {
Path string
Index recordio.Index // chunk index
}
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
ID int
Chunks []Chunk
}
type taskEntry struct {
Epoch int
NumTimeout int
Task Task
}
type taskQueues struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []Task
}
// Service is the master server service. // Service is the master server service.
type Service struct { type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
timeoutMax int timeoutMax int
ready chan struct{} ready chan struct{}
store Store
mu sync.Mutex mu sync.Mutex
initDone bool initDone bool
taskQueues taskQueues taskQueues taskQueues
} }
// Recover recovers service state from etcd.
func Recover() (*Service, error) {
// TODO(helin): recover from snapshot state from etcd.
return nil, nil
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 id := 0
if chunksPerTask <= 0 { if chunksPerTask <= 0 {
...@@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) {
s := &Service{} s := &Service{}
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
...@@ -66,38 +99,82 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se ...@@ -66,38 +99,82 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
return s s.store = store
} recovered, err := s.recover()
if err != nil {
return nil, err
}
// Chunk is a chunk of data consisted of several data instances. if recovered {
type Chunk struct { // Recovered. Now the state is already initialized,
Path string // and the master is ready.
Index recordio.Index // chunk index s.initDone = true
} close(s.ready)
log.Info("Master recovered from saved state.")
}
// Task is the basic unit of data instances assigned to trainers. return s, nil
type Task struct {
ID int
Chunks []Chunk
} }
type taskEntry struct { // recover recovers service state from etcd.
Epoch int func (s *Service) recover() (bool, error) {
NumTimeout int state, err := s.store.Load()
Task Task if err != nil {
} return false, err
}
type taskQueues struct { if state == nil {
Todo []taskEntry log.Infoln("No state exists, not recovered.")
Pending map[int]taskEntry // map from task ID to task entry return false, nil
Done []taskEntry }
Failed []Task
log.Infof("Loaded snapshot of size: %d bytes.", len(state))
gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil {
return false, err
}
dec := gob.NewDecoder(gr)
var tqs taskQueues
err = dec.Decode(&tqs)
if err != nil {
return false, err
}
err = gr.Close()
if err != nil {
// Only close failed, recover actually succeed, so
// just log error.
log.Errorln(err)
}
s.taskQueues = tqs
return true, nil
} }
// *must* be called with s.mu being held. // snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error { func (s *Service) snapshot() error {
// TODO(helin): snapshot state on etcd. // TOOD(helin): etcd request has a size limit, so the snapshot
return nil // size is limited by the max request size. We should either
// divide the snapshot into smaller chunks and save under
// different keys, or configure the request size to be big
// enough:
// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
enc := gob.NewEncoder(gw)
err := enc.Encode(s.taskQueues)
if err != nil {
return err
}
err = gw.Close()
if err != nil {
return err
}
state := buf.Bytes()
log.Infof("Saving snapshot of size: %d bytes.", len(state))
return s.store.Save(state)
} }
func readChunks(globPaths []string) ([]Chunk, error) { func readChunks(globPaths []string) ([]Chunk, error) {
...@@ -207,12 +284,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -207,12 +284,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
t.NumTimeout++ t.NumTimeout++
if t.NumTimeout > s.timeoutMax { if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout) log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return return
} }
log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout) log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.taskQueues.Todo = append(s.taskQueues.Todo, t)
} }
} }
......
...@@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, ...@@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Errorln(err)
...@@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
...@@ -30,9 +31,12 @@ func init() { ...@@ -30,9 +31,12 @@ func init() {
port[i] = p port[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
package pserver package pserver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -24,6 +33,9 @@ const ( ...@@ -24,6 +33,9 @@ const (
Float64 Float64
) )
// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server. // Parameter is a piece of data to sync with the parameter server.
type Parameter struct { type Parameter struct {
Name string Name string
...@@ -47,14 +59,128 @@ type Service struct { ...@@ -47,14 +59,128 @@ type Service struct {
mu sync.Mutex mu sync.Mutex
opt *optimizer opt *optimizer
paramMap map[string]Parameter paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
} }
// NewService creates a new service. // NewService creates a new service, will bypass etcd registration if no
func NewService() *Service { // endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)} s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
} }
// InitParam initializes a parameter. // InitParam initializes a parameter.
......
...@@ -10,12 +10,15 @@ import ( ...@@ -10,12 +10,15 @@ import (
) )
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
var p pserver.Parameter var p pserver.Parameter
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -72,8 +75,11 @@ func TestFull(t *testing.T) { ...@@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
err := s.FinishInitParams(0, nil) if err != nil {
t.Error(err)
}
err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) { ...@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
err := s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
} }
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
errCh := make(chan error, 2) errCh := make(chan error, 2)
var wg sync.WaitGroup var wg sync.WaitGroup
...@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
......
package networkhelper
import (
"errors"
"net"
)
// GetExternalIP returns the ip address of local network interface, not the
// loopback device.
func GetExternalIP() (string, error) {
ifaces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip.String(), nil
}
}
return "", errors.New("are you connected to the network?")
}
package networkhelper
import "testing"
func TestGetIP(t *testing.T) {
_, err := GetExternalIP()
if err != nil {
t.Errorf("GetExternalIP returns error : %v\n", err)
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Evaluator.h"
#include "paddle/gserver/layers/DetectionUtil.h"
using std::map;
using std::vector;
using std::pair;
using std::make_pair;
namespace paddle {
/**
* @brief detection map Evaluator
*
* The config file api is detection_map_evaluator.
*/
class DetectionMAPEvaluator : public Evaluator {
public:
DetectionMAPEvaluator()
: evaluateDifficult_(false), cpuOutput_(nullptr), cpuLabel_(nullptr) {}
virtual void start() {
Evaluator::start();
allTruePos_.clear();
allFalsePos_.clear();
numPos_.clear();
}
virtual real evalImp(std::vector<Argument>& arguments) {
overlapThreshold_ = config_.overlap_threshold();
backgroundId_ = config_.background_id();
evaluateDifficult_ = config_.evaluate_difficult();
apType_ = config_.ap_type();
MatrixPtr detectTmpValue = arguments[0].value;
Matrix::resizeOrCreate(cpuOutput_,
detectTmpValue->getHeight(),
detectTmpValue->getWidth(),
false,
false);
MatrixPtr labelTmpValue = arguments[1].value;
Matrix::resizeOrCreate(cpuLabel_,
labelTmpValue->getHeight(),
labelTmpValue->getWidth(),
false,
false);
cpuOutput_->copyFrom(*detectTmpValue);
cpuLabel_->copyFrom(*labelTmpValue);
Argument label = arguments[1];
const int* labelIndex = label.sequenceStartPositions->getData(false);
size_t batchSize = label.getNumSequences();
vector<map<size_t, vector<NormalizedBBox>>> allGTBBoxes;
vector<map<size_t, vector<pair<real, NormalizedBBox>>>> allDetectBBoxes;
for (size_t n = 0; n < batchSize; ++n) {
map<size_t, vector<NormalizedBBox>> bboxes;
for (int i = labelIndex[n]; i < labelIndex[n + 1]; ++i) {
vector<NormalizedBBox> bbox;
getBBoxFromLabelData(cpuLabel_->getData() + i * 6, 1, bbox);
int c = cpuLabel_->getData()[i * 6];
bboxes[c].push_back(bbox[0]);
}
allGTBBoxes.push_back(bboxes);
}
size_t n = 0;
const real* cpuOutputData = cpuOutput_->getData();
for (size_t imgId = 0; imgId < batchSize; ++imgId) {
map<size_t, vector<pair<real, NormalizedBBox>>> bboxes;
size_t curImgId = static_cast<size_t>((cpuOutputData + n * 7)[0]);
while (curImgId == imgId && n < cpuOutput_->getHeight()) {
vector<real> label;
vector<real> score;
vector<NormalizedBBox> bbox;
getBBoxFromDetectData(cpuOutputData + n * 7, 1, label, score, bbox);
bboxes[label[0]].push_back(make_pair(score[0], bbox[0]));
++n;
curImgId = static_cast<size_t>((cpuOutputData + n * 7)[0]);
}
allDetectBBoxes.push_back(bboxes);
}
for (size_t n = 0; n < batchSize; ++n) {
for (map<size_t, vector<NormalizedBBox>>::iterator it =
allGTBBoxes[n].begin();
it != allGTBBoxes[n].end();
++it) {
size_t count = 0;
if (evaluateDifficult_) {
count = it->second.size();
} else {
for (size_t i = 0; i < it->second.size(); ++i)
if (!(it->second[i].isDifficult)) ++count;
}
if (numPos_.find(it->first) == numPos_.end() && count != 0) {
numPos_[it->first] = count;
} else {
numPos_[it->first] += count;
}
}
}
// calcTFPos
calcTFPos(batchSize, allGTBBoxes, allDetectBBoxes);
return 0;
}
virtual void printStats(std::ostream& os) const {
real mAP = calcMAP();
os << "Detection mAP=" << mAP;
}
virtual void distributeEval(ParameterClient2* client) {
LOG(FATAL) << "Distribute detection evaluation not implemented.";
}
protected:
void calcTFPos(const size_t batchSize,
const vector<map<size_t, vector<NormalizedBBox>>>& allGTBBoxes,
const vector<map<size_t, vector<pair<real, NormalizedBBox>>>>&
allDetectBBoxes) {
for (size_t n = 0; n < allDetectBBoxes.size(); ++n) {
if (allGTBBoxes[n].size() == 0) {
for (map<size_t, vector<pair<real, NormalizedBBox>>>::const_iterator
it = allDetectBBoxes[n].begin();
it != allDetectBBoxes[n].end();
++it) {
size_t label = it->first;
for (size_t i = 0; i < it->second.size(); ++i) {
allTruePos_[label].push_back(make_pair(it->second[i].first, 0));
allFalsePos_[label].push_back(make_pair(it->second[i].first, 1));
}
}
} else {
for (map<size_t, vector<pair<real, NormalizedBBox>>>::const_iterator
it = allDetectBBoxes[n].begin();
it != allDetectBBoxes[n].end();
++it) {
size_t label = it->first;
vector<pair<real, NormalizedBBox>> predBBoxes = it->second;
if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) {
for (size_t i = 0; i < predBBoxes.size(); ++i) {
allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0));
allFalsePos_[label].push_back(make_pair(predBBoxes[i].first, 1));
}
} else {
vector<NormalizedBBox> gtBBoxes =
allGTBBoxes[n].find(label)->second;
vector<bool> visited(gtBBoxes.size(), false);
// Sort detections in descend order based on scores
std::sort(predBBoxes.begin(),
predBBoxes.end(),
sortScorePairDescend<NormalizedBBox>);
for (size_t i = 0; i < predBBoxes.size(); ++i) {
real maxOverlap = -1.0;
size_t maxIdx = 0;
for (size_t j = 0; j < gtBBoxes.size(); ++j) {
real overlap =
jaccardOverlap(predBBoxes[i].second, gtBBoxes[j]);
if (overlap > maxOverlap) {
maxOverlap = overlap;
maxIdx = j;
}
}
if (maxOverlap > overlapThreshold_) {
if (evaluateDifficult_ ||
(!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) {
if (!visited[maxIdx]) {
allTruePos_[label].push_back(
make_pair(predBBoxes[i].first, 1));
allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 0));
visited[maxIdx] = true;
} else {
allTruePos_[label].push_back(
make_pair(predBBoxes[i].first, 0));
allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 1));
}
}
} else {
allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0));
allFalsePos_[label].push_back(
make_pair(predBBoxes[i].first, 1));
}
}
}
}
}
}
}
real calcMAP() const {
real mAP = 0.0;
size_t count = 0;
for (map<size_t, size_t>::const_iterator it = numPos_.begin();
it != numPos_.end();
++it) {
size_t label = it->first;
size_t labelNumPos = it->second;
if (labelNumPos == 0 || allTruePos_.find(label) == allTruePos_.end())
continue;
vector<pair<real, size_t>> labelTruePos = allTruePos_.find(label)->second;
vector<pair<real, size_t>> labelFalsePos =
allFalsePos_.find(label)->second;
// Compute average precision.
vector<size_t> tpCumSum;
getAccumulation(labelTruePos, &tpCumSum);
vector<size_t> fpCumSum;
getAccumulation(labelFalsePos, &fpCumSum);
std::vector<real> precision, recall;
size_t num = tpCumSum.size();
// Compute Precision.
for (size_t i = 0; i < num; ++i) {
CHECK_LE(tpCumSum[i], labelNumPos);
precision.push_back(static_cast<real>(tpCumSum[i]) /
static_cast<real>(tpCumSum[i] + fpCumSum[i]));
recall.push_back(static_cast<real>(tpCumSum[i]) / labelNumPos);
}
// VOC2007 style
if (apType_ == "11point") {
vector<real> maxPrecisions(11, 0.0);
int startIdx = num - 1;
for (int j = 10; j >= 0; --j)
for (int i = startIdx; i >= 0; --i) {
if (recall[i] < j / 10.) {
startIdx = i;
if (j > 0) maxPrecisions[j - 1] = maxPrecisions[j];
break;
} else {
if (maxPrecisions[j] < precision[i])
maxPrecisions[j] = precision[i];
}
}
for (int j = 10; j >= 0; --j) mAP += maxPrecisions[j] / 11;
++count;
} else if (apType_ == "Integral") {
// Nature integral
real averagePrecisions = 0.;
real prevRecall = 0.;
for (size_t i = 0; i < num; ++i) {
if (fabs(recall[i] - prevRecall) > 1e-6)
averagePrecisions += precision[i] * fabs(recall[i] - prevRecall);
prevRecall = recall[i];
}
mAP += averagePrecisions;
++count;
} else {
LOG(FATAL) << "Unkown ap version: " << apType_;
}
}
if (count != 0) mAP /= count;
return mAP * 100;
}
void getAccumulation(vector<pair<real, size_t>> inPairs,
vector<size_t>* accuVec) const {
std::stable_sort(
inPairs.begin(), inPairs.end(), sortScorePairDescend<size_t>);
accuVec->clear();
size_t sum = 0;
for (size_t i = 0; i < inPairs.size(); ++i) {
sum += inPairs[i].second;
accuVec->push_back(sum);
}
}
std::string getTypeImpl() const { return "detection_map"; }
real getValueImpl() const { return calcMAP(); }
private:
real overlapThreshold_; // overlap threshold when determining whether matched
bool evaluateDifficult_; // whether evaluate difficult ground truth
size_t backgroundId_; // class index of background
std::string apType_; // how to calculate mAP (Integral or 11point)
MatrixPtr cpuOutput_;
MatrixPtr cpuLabel_;
map<size_t, size_t> numPos_; // counts of true objects each classification
map<size_t, vector<pair<real, size_t>>>
allTruePos_; // true positive prediction
map<size_t, vector<pair<real, size_t>>>
allFalsePos_; // false positive prediction
};
REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator);
} // namespace paddle
...@@ -309,35 +309,35 @@ public: ...@@ -309,35 +309,35 @@ public:
void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) { void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) {
evaluators_.emplace_back(std::move(evaluator)); evaluators_.emplace_back(std::move(evaluator));
} }
virtual void start() { void start() override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->start(); evaluator->start();
} }
} }
virtual void finish() { void finish() override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->finish(); evaluator->finish();
} }
} }
virtual void eval(const NeuralNetwork& nn) override { void eval(const NeuralNetwork& nn) override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->eval(nn); evaluator->eval(nn);
} }
} }
virtual real evalImp(std::vector<Argument>& arguments) { real evalImp(std::vector<Argument>& arguments) override {
(void)arguments; (void)arguments;
return -1; return -1;
} }
virtual void printStats(std::ostream& os) const { void printStats(std::ostream& os) const override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->printStats(os); evaluator->printStats(os);
os << ' '; os << ' ';
} }
} }
virtual void distributeEval(ParameterClient2* client) { void distributeEval(ParameterClient2* client) override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->distributeEval(client); evaluator->distributeEval(client);
} }
...@@ -352,7 +352,7 @@ public: ...@@ -352,7 +352,7 @@ public:
* @brief getNames will return all inside evaluators' names. * @brief getNames will return all inside evaluators' names.
* @param names [out]: return names. * @param names [out]: return names.
*/ */
void getNames(std::vector<std::string>* names) { void getNames(std::vector<std::string>* names) override {
for (auto& eval : evaluators_) { for (auto& eval : evaluators_) {
eval->getNames(names); eval->getNames(names);
} }
...@@ -361,7 +361,7 @@ public: ...@@ -361,7 +361,7 @@ public:
/** /**
* @brief getValue could get all inside evaluators' value. * @brief getValue could get all inside evaluators' value.
*/ */
real getValue(const std::string& name, Error* err) const { real getValue(const std::string& name, Error* err) const override {
return this->getMethodHelper<real>( return this->getMethodHelper<real>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) { name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getValue(name, err); return eval->getValue(name, err);
...@@ -371,7 +371,7 @@ public: ...@@ -371,7 +371,7 @@ public:
/** /**
* @brief getType could get all inside evaluators' type. * @brief getType could get all inside evaluators' type.
*/ */
std::string getType(const std::string& name, Error* err) const { std::string getType(const std::string& name, Error* err) const override {
return this->getMethodHelper<std::string>( return this->getMethodHelper<std::string>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) { name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getType(name, err); return eval->getType(name, err);
......
...@@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf, ...@@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf,
testEvaluator(testConf, testEvaluatorName, batchSize, false); testEvaluator(testConf, testEvaluatorName, batchSize, false);
} }
TEST(Evaluator, detection_map) {
TestConfig config;
config.evaluatorConfig.set_type("detection_map");
config.evaluatorConfig.set_overlap_threshold(0.5);
config.evaluatorConfig.set_background_id(0);
config.evaluatorConfig.set_ap_type("Integral");
config.evaluatorConfig.set_evaluate_difficult(0);
config.inputDefs.push_back({INPUT_DATA, "output", 7});
config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "label", 6});
config.evaluatorConfig.set_evaluate_difficult(false);
testEvaluatorAll(config, "detection_map", 100);
config.evaluatorConfig.set_evaluate_difficult(true);
testEvaluatorAll(config, "detection_map", 100);
}
TEST(Evaluator, classification_error) { TEST(Evaluator, classification_error) {
TestConfig config; TestConfig config;
config.evaluatorConfig.set_type("classification_error"); config.evaluatorConfig.set_type("classification_error");
......
...@@ -14,11 +14,13 @@ limitations under the License. */ ...@@ -14,11 +14,13 @@ limitations under the License. */
#include "ParameterUpdaterHook.h" #include "ParameterUpdaterHook.h"
#include <algorithm>
#include <atomic> #include <atomic>
#include <fstream> #include <fstream>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h" #include "paddle/parameter/Parameter.h"
...@@ -29,106 +31,76 @@ namespace paddle { ...@@ -29,106 +31,76 @@ namespace paddle {
/** /**
* The static pruning hook * The static pruning hook
* * Static means user specify a sparsity_ratio before training started, and the
* Static means user load a mask map before training started. This map will * network will prune the parameters based on the sparsity_ratio. More details
* define which link/weight between neural is disabled. * can be found https://arxiv.org/pdf/1506.02626.pdf.
*/ */
class StaticPruningHook : public IParameterUpdaterHook { class StaticPruningHook : public IParameterUpdaterHook {
public: public:
/** explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
* The Mask Map Header. : initCount_(0) {
* The map file started with this header. sparsityRatio_ = hookConfig.sparsity_ratio();
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/
struct StaticMaskHeader {
uint32_t version;
size_t size;
} __attribute__((__packed__));
explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) {
bool ok = this->loadMaskFile(mask_filename);
if (!ok) {
LOG(WARNING) << "Fail to load mask file " << mask_filename
<< " in current directory, searching in init_model_path";
std::string combineMaskFilename =
path::join(FLAGS_init_model_path, mask_filename);
CHECK(this->loadMaskFile(combineMaskFilename))
<< "Cannot load " << mask_filename << " in ./" << mask_filename
<< " and " << combineMaskFilename;
} }
VLOG(3) << mask_filename << " mask size = " << this->mask_.size();
static bool sortPairAscend(const std::pair<real, size_t> &pair1,
const std::pair<real, size_t> &pair2) {
return pair1.first > pair2.first;
} }
void update(Parameter* para) { void update(Parameter *para) {
updateThreadChecker_.check(); updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT); auto &vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) { if (vec) {
vec->dotMul(*maskVec_); vec->dotMul(*maskVec_);
} }
} }
void init(Parameter* para) { void generateMask(Parameter *para) {
size_t initCount = this->initCount_.fetch_add(1); VectorPtr maskTemp = Vector::create(para->getSize(), false);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " maskTemp->zeroMem();
"in same ParamterUpdater"; real *maskTempData = maskTemp->getData();
VLOG(3) << "Initialize Parameter " << para; size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
SetDevice device(para->getDeviceId());
auto maskVec = Vector::create(this->mask_.size(), false); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
{ // Initialize maskVec with float mask vector VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
real* dataPtr = maskVec->getData();
size_t i = 0; paraCpuCopy->copyFrom(*paraVec);
for (bool m : mask_) { std::vector<std::pair<real, size_t>> param;
dataPtr[i++] = m ? 1.0 : 0.0;
} for (size_t i = 0; i < para->getSize(); i++)
} param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack. // Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) { if (para->useGpu()) {
maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskVec); maskVec_->copyFrom(*maskTemp);
} else { } else {
maskVec_ = maskVec; maskVec_ = maskTemp;
} }
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
} }
private: void init(Parameter *para) {
bool loadMaskFile(const std::string& mask_filename) { generateMask(para);
std::ifstream fin; size_t initCount = this->initCount_.fetch_add(1);
fin.open(mask_filename); CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
if (fin.is_open()) { "in same ParamterUpdater";
StaticMaskHeader header; VLOG(3) << "Initialize Parameter " << para;
fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader)); SetDevice device(para->getDeviceId());
CHECK_EQ(header.version, 0UL);
mask_.resize(header.size); auto &paraVec = para->getBuf(PARAMETER_VALUE);
uint8_t buf; paraVec->dotMul(*maskVec_);
for (size_t i = 0; i < header.size; ++i, buf <<= 1) {
if (i % 8 == 0) {
fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t));
}
mask_[i] = buf & 0x80;
}
fin.close();
return true;
} else {
return false;
}
} }
private:
SameThreadChecker updateThreadChecker_; SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_; std::atomic<size_t> initCount_;
VectorPtr maskVec_; VectorPtr maskVec_;
std::vector<bool> mask_; real sparsityRatio_;
}; };
IParameterUpdaterHook::IParameterUpdaterHook() {} IParameterUpdaterHook::IParameterUpdaterHook() {}
...@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {} ...@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/ */
class StringIntPairHasher { class StringIntPairHasher {
public: public:
size_t operator()(const std::pair<std::string, int>& k) const { size_t operator()(const std::pair<std::string, int> &k) const {
return intHasher_(strHasher_(k.first) + k.second); return intHasher_(strHasher_(k.first) + k.second);
} }
...@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>, ...@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
/** /**
* ParameterUpdaterHook actually factory method. * ParameterUpdaterHook actually factory method.
*/ */
static IParameterUpdaterHook* createImpl( static IParameterUpdaterHook *createImpl(
const ParameterUpdaterHookConfig& config) { const ParameterUpdaterHookConfig &config) {
auto& type = config.type(); auto &type = config.type();
if (type == "pruning") { if (type == "pruning") {
if (config.has_purning_mask_filename()) { return new StaticPruningHook(config);
return new StaticPruningHook(config.purning_mask_filename());
}
} }
LOG(FATAL) << "Unknown Hook type: " << type;
return nullptr; return nullptr;
} }
std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create( std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create(
const ParameterConfig& paramConfig, int idx) { const ParameterConfig &paramConfig, int idx) {
std::pair<std::string, int> key = {paramConfig.name(), idx}; std::pair<std::string, int> key = {paramConfig.name(), idx};
return g_hookCache_.get( return g_hookCache_.get(
key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); key, [&] { return createImpl(paramConfig.update_hooks(idx)); });
......
#!/bin/bash
source ./common.sh
NPROC=1
export PYTHONPATH=/opt/python/2.7.12/lib/python2.7/site-packages
export PYTHONHOME=/opt/python/2.7.12
export PATH=/opt/python/2.7.12/bin:${PATH}
cmake .. -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DON_TRAVIS=ON -DWITH_COVERAGE=ON -DCOVERALLS_UPLOAD=ON ${EXTRA_CMAKE_OPTS}
NRPOC=`nproc`
make -j $NPROC
make coveralls
sudo make install
#!/bin/bash #!/bin/bash
set -e
# Create the build directory for CMake.
mkdir -p $TRAVIS_BUILD_DIR/build
cd $TRAVIS_BUILD_DIR/build
# Add set -e, cd to directory.
source ./common.sh
# Compile Documentation only. # Compile Documentation only.
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF ${EXTRA_CMAKE_OPTS} cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF
mkdir output mkdir output
make -j `nproc` make -j `nproc`
find .. -name '*whl' | xargs pip install # install all wheels. find .. -name '*whl' | xargs pip install # install all wheels.
rm -rf * rm -rf *
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=ON ${EXTRA_CMAKE_OPTS} cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON
make paddle_docs paddle_docs_cn make -j `nproc` paddle_docs paddle_docs_cn
# check websites for broken links # check websites for broken links
linkchecker doc/en/html/index.html linkchecker doc/en/html/index.html
......
#!/bin/bash #!/bin/bash
function abort(){ function abort(){
echo "Your commit not fit PaddlePaddle code style" 1>&2 echo "Your change doesn't follow PaddlePaddle's code style." 1>&2
echo "Please use pre-commit scripts to auto-format your code" 1>&2 echo "Please use pre-commit to reformat your code and git push again." 1>&2
exit 1 exit 1
} }
trap 'abort' 0 trap 'abort' 0
set -e set -e
source common.sh
cd .. cd $TRAVIS_BUILD_DIR
export PATH=/usr/bin:$PATH export PATH=/usr/bin:$PATH
pre-commit install pre-commit install
clang-format --version clang-format --version
......
#!/bin/bash
set -e
mkdir -p ../../../build
cd ../../../build
mkdir -p $HOME/third_party
EXTRA_CMAKE_OPTS="-DTHIRD_PARTY_PATH=${HOME}/third_party"
#!/bin/bash
cd `dirname $0`
if [ ${JOB} == "BUILD_AND_TEST" ]; then
./build_and_test.sh
elif [ ${JOB} == "DOCS" ]; then
./docs.sh
elif [ ${JOB} == "PRE_COMMIT" ]; then
./precommit.sh
else
echo Unknown job ${JOB}
exit 1
fi
...@@ -489,6 +489,15 @@ message EvaluatorConfig { ...@@ -489,6 +489,15 @@ message EvaluatorConfig {
// Used by ClassificationErrorEvaluator // Used by ClassificationErrorEvaluator
// top # classification error // top # classification error
optional int32 top_k = 13 [default = 1]; optional int32 top_k = 13 [default = 1];
// Used by DetectionMAPEvaluator
optional double overlap_threshold = 14 [default = 0.5];
optional int32 background_id = 15 [default = 0];
optional bool evaluate_difficult = 16 [default = false];
optional string ap_type = 17 [default = "11point"];
} }
message LinkConfig { message LinkConfig {
......
...@@ -25,8 +25,10 @@ enum ParameterInitStrategy { ...@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
} }
message ParameterUpdaterHookConfig { message ParameterUpdaterHookConfig {
// hook type such as 'pruning'
required string type = 1; required string type = 1;
optional string purning_mask_filename = 2; // this represents the ratio of zero element to be set by the Parameter
optional double sparsity_ratio = 2 [default = 0.6];
} }
message ParameterConfig { message ParameterConfig {
......
...@@ -1280,8 +1280,7 @@ def parse_maxout(maxout, input_layer_name, maxout_conf): ...@@ -1280,8 +1280,7 @@ def parse_maxout(maxout, input_layer_name, maxout_conf):
# Define an evaluator # Define an evaluator
@config_func @config_func
def Evaluator( def Evaluator(name,
name,
type, type,
inputs, inputs,
chunk_scheme=None, chunk_scheme=None,
...@@ -1293,7 +1292,11 @@ def Evaluator( ...@@ -1293,7 +1292,11 @@ def Evaluator(
num_results=None, num_results=None,
top_k=None, top_k=None,
delimited=None, delimited=None,
excluded_chunk_types=None, ): excluded_chunk_types=None,
overlap_threshold=None,
background_id=None,
evaluate_difficult=None,
ap_type=None):
evaluator = g_config.model_config.evaluators.add() evaluator = g_config.model_config.evaluators.add()
evaluator.type = type evaluator.type = type
evaluator.name = MakeLayerNameInSubmodel(name) evaluator.name = MakeLayerNameInSubmodel(name)
...@@ -1327,6 +1330,18 @@ def Evaluator( ...@@ -1327,6 +1330,18 @@ def Evaluator(
if excluded_chunk_types: if excluded_chunk_types:
evaluator.excluded_chunk_types.extend(excluded_chunk_types) evaluator.excluded_chunk_types.extend(excluded_chunk_types)
if overlap_threshold is not None:
evaluator.overlap_threshold = overlap_threshold
if background_id is not None:
evaluator.background_id = background_id
if evaluate_difficult is not None:
evaluator.evaluate_difficult = evaluate_difficult
if ap_type is not None:
evaluator.ap_type = ap_type
class LayerBase(object): class LayerBase(object):
def __init__( def __init__(
...@@ -3124,11 +3139,11 @@ def Layer(name, type, **xargs): ...@@ -3124,11 +3139,11 @@ def Layer(name, type, **xargs):
@config_func @config_func
def ParameterHook(type, **kwargs): def ParameterHook(type, **kwargs):
if type == 'pruning': if type == 'pruning':
mask_filename = kwargs.get('mask_filename', None)
assert mask_filename is not None
hook = ParameterUpdaterHookConfig() hook = ParameterUpdaterHookConfig()
hook.type = type hook.type = type
hook.purning_mask_filename = mask_filename sparsity_ratio = kwargs.get('sparsity_ratio', None)
if sparsity_ratio is not None:
hook.sparsity_ratio = sparsity_ratio
return hook return hook
else: else:
return None return None
...@@ -3236,13 +3251,13 @@ def Parameter(name, ...@@ -3236,13 +3251,13 @@ def Parameter(name,
if update_hooks is not None: if update_hooks is not None:
if hasattr(update_hooks, '__call__'): if hasattr(update_hooks, '__call__'):
update_hooks = update_hooks(para.name) update_hooks = update_hooks()
if isinstance(update_hooks, list): if isinstance(update_hooks, list):
for hook in update_hooks: for hook in update_hooks:
para.update_hooks.extend([hook]) para.update_hooks.extend([hook])
else: else:
para.update_hooks.extend(update_hooks) para.update_hooks.extend([update_hooks])
g_parameter_map[name] = para g_parameter_map[name] = para
if initializer is not None: if initializer is not None:
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
from paddle.trainer.config_parser import * from paddle.trainer.config_parser import *
__all__ = [ __all__ = [
'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute' 'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute',
'ExtraLayerAttribute'
] ]
...@@ -55,6 +56,40 @@ def is_compatible_with(x, Type): ...@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
return False return False
class HookAttribute(object):
"""
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None
"""
def __init__(self, type, sparsity_ratio=None):
self.type = type
self.sparsity_ratio = sparsity_ratio
if self.sparsity_ratio is not None:
assert is_compatible_with(
self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparsity_ratio must be a float between [0, 1] '
def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
class ParameterAttribute(object): class ParameterAttribute(object):
""" """
Parameter Attributes object. To fine-tuning network training process, user Parameter Attributes object. To fine-tuning network training process, user
...@@ -114,6 +149,7 @@ class ParameterAttribute(object): ...@@ -114,6 +149,7 @@ class ParameterAttribute(object):
momentum=None, momentum=None,
gradient_clipping_threshold=None, gradient_clipping_threshold=None,
sparse_update=False, sparse_update=False,
update_hooks=None,
initializer=None): initializer=None):
self.attr = {} self.attr = {}
...@@ -169,6 +205,9 @@ class ParameterAttribute(object): ...@@ -169,6 +205,9 @@ class ParameterAttribute(object):
if initializer is not None: if initializer is not None:
self.attr['initializer'] = initializer self.attr['initializer'] = initializer
if update_hooks:
self.attr['update_hooks'] = update_hooks
def set_default_parameter_name(self, name): def set_default_parameter_name(self, name):
""" """
Set default parameter name. If parameter not set, then will use default Set default parameter name. If parameter not set, then will use default
...@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object): ...@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
return attr.attr return attr.attr
HookAttr = HookAttribute
ParamAttr = ParameterAttribute ParamAttr = ParameterAttribute
ExtraAttr = ExtraLayerAttribute ExtraAttr = ExtraLayerAttribute
...@@ -21,7 +21,8 @@ __all__ = [ ...@@ -21,7 +21,8 @@ __all__ = [
"chunk_evaluator", "sum_evaluator", "column_sum_evaluator", "chunk_evaluator", "sum_evaluator", "column_sum_evaluator",
"value_printer_evaluator", "gradient_printer_evaluator", "value_printer_evaluator", "gradient_printer_evaluator",
"maxid_printer_evaluator", "maxframe_printer_evaluator", "maxid_printer_evaluator", "maxframe_printer_evaluator",
"seqtext_printer_evaluator", "classification_error_printer_evaluator" "seqtext_printer_evaluator", "classification_error_printer_evaluator",
"detection_map_evaluator"
] ]
...@@ -31,10 +32,11 @@ class EvaluatorAttribute(object): ...@@ -31,10 +32,11 @@ class EvaluatorAttribute(object):
FOR_RANK = 1 << 2 FOR_RANK = 1 << 2
FOR_PRINT = 1 << 3 FOR_PRINT = 1 << 3
FOR_UTILS = 1 << 4 FOR_UTILS = 1 << 4
FOR_DETECTION = 1 << 5
KEYS = [ KEYS = [
"for_classification", "for_regression", "for_rank", "for_print", "for_classification", "for_regression", "for_rank", "for_print",
"for_utils" "for_utils", "for_detection"
] ]
@staticmethod @staticmethod
...@@ -57,8 +59,7 @@ def evaluator(*attrs): ...@@ -57,8 +59,7 @@ def evaluator(*attrs):
return impl return impl
def evaluator_base( def evaluator_base(input,
input,
type, type,
label=None, label=None,
weight=None, weight=None,
...@@ -72,7 +73,11 @@ def evaluator_base( ...@@ -72,7 +73,11 @@ def evaluator_base(
num_results=None, num_results=None,
delimited=None, delimited=None,
top_k=None, top_k=None,
excluded_chunk_types=None, ): excluded_chunk_types=None,
overlap_threshold=None,
background_id=None,
evaluate_difficult=None,
ap_type=None):
""" """
Evaluator will evaluate the network status while training/testing. Evaluator will evaluate the network status while training/testing.
...@@ -107,6 +112,14 @@ def evaluator_base( ...@@ -107,6 +112,14 @@ def evaluator_base(
:type weight: LayerOutput. :type weight: LayerOutput.
:param top_k: number k in top-k error rate :param top_k: number k in top-k error rate
:type top_k: int :type top_k: int
:param overlap_threshold: In detection tasks to filter detection results
:type overlap_threshold: float
:param background_id: Identifier of background class
:type background_id: int
:param evaluate_difficult: Whether to evaluate difficult objects
:type evaluate_difficult: bool
:param ap_type: How to calculate average persicion
:type ap_type: str
""" """
# inputs type assertions. # inputs type assertions.
assert classification_threshold is None or isinstance( assert classification_threshold is None or isinstance(
...@@ -136,7 +149,61 @@ def evaluator_base( ...@@ -136,7 +149,61 @@ def evaluator_base(
delimited=delimited, delimited=delimited,
num_results=num_results, num_results=num_results,
top_k=top_k, top_k=top_k,
excluded_chunk_types=excluded_chunk_types, ) excluded_chunk_types=excluded_chunk_types,
overlap_threshold=overlap_threshold,
background_id=background_id,
evaluate_difficult=evaluate_difficult,
ap_type=ap_type)
@evaluator(EvaluatorAttribute.FOR_DETECTION)
@wrap_name_default()
def detection_map_evaluator(input,
label,
overlap_threshold=0.5,
background_id=0,
evaluate_difficult=False,
ap_type="11point",
name=None):
"""
Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection.
The detection mAP Evaluator based on the output of detection_output layer counts
the true positive and the false positive bbox and integral them to get the
mAP.
The simple usage is:
.. code-block:: python
eval = detection_map_evaluator(input=det_output,label=lbl)
:param input: Input layer.
:type input: LayerOutput
:param label: Label layer.
:type label: LayerOutput
:param overlap_threshold: The bbox overlap threshold of a true positive.
:type overlap_threshold: float
:param background_id: The background class index.
:type background_id: int
:param evaluate_difficult: Whether evaluate a difficult ground truth.
:type evaluate_difficult: bool
"""
if not isinstance(input, list):
input = [input]
if label:
input.append(label)
evaluator_base(
name=name,
type="detection_map",
input=input,
label=label,
overlap_threshold=overlap_threshold,
background_id=background_id,
evaluate_difficult=evaluate_difficult,
ap_type=ap_type)
@evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
......
...@@ -3839,7 +3839,8 @@ def classification_cost(input, ...@@ -3839,7 +3839,8 @@ def classification_cost(input,
weight=None, weight=None,
name=None, name=None,
evaluator=classification_error_evaluator, evaluator=classification_error_evaluator,
layer_attr=None): layer_attr=None,
coeff=1.):
""" """
classification cost Layer. classification cost Layer.
...@@ -3855,6 +3856,8 @@ def classification_cost(input, ...@@ -3855,6 +3856,8 @@ def classification_cost(input,
:param evaluator: Evaluator method. :param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute. :param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute :type layer_attr: ExtraLayerAttribute
:param coeff: The coefficient affects the gradient in the backward.
:type coeff: float
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -3868,6 +3871,7 @@ def classification_cost(input, ...@@ -3868,6 +3871,7 @@ def classification_cost(input,
name=name, name=name,
type="multi-class-cross-entropy", type="multi-class-cross-entropy",
inputs=ipts, inputs=ipts,
coeff=coeff,
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
def __add_evaluator__(e): def __add_evaluator__(e):
......
...@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs ...@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__ = [ __all__ = [
"Param", "Param",
"Extra", "Extra",
"Hook",
] ]
Param = paddle.trainer_config_helpers.attrs.ParameterAttribute Param = paddle.trainer_config_helpers.attrs.ParameterAttribute
Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute
Hook = paddle.trainer_config_helpers.attrs.HookAttribute
for each in paddle.trainer_config_helpers.attrs.__all__: for each in paddle.trainer_config_helpers.attrs.__all__:
globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each) globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册