提交 40483c11 编写于 作者: Z zchen0211

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

INCLUDE(ExternalProject) include(ExternalProject)
SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl) set(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
include_directories(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO) if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies # If we use DSO, we do not build nccl, just download the dependencies
...@@ -12,39 +11,39 @@ if(WITH_DSO) ...@@ -12,39 +11,39 @@ if(WITH_DSO)
set(NCCL_INSTALL_DIR "") set(NCCL_INSTALL_DIR "")
else() else()
# otherwise, we build nccl and link it. # otherwise, we build nccl and link it.
set(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
# Note: cuda 8.0 is needed to make nccl
# When cuda is not installed on the system directory, need to set CUDA_HOME to your cuda root
set(NCCL_BUILD_COMMAND "make -j 8") set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install") set(NCCL_INSTALL_COMMAND "make install PREFIX=${NCCL_INSTALL_DIR}")
SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
endif() endif()
ExternalProject_Add( ExternalProject_Add(
extern_nccl extern_nccl
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git" GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
GIT_TAG "v1.3.4-1" GIT_TAG "v1.3.4-1"
PREFIX "${NCCL_SOURCE_DIR}" PREFIX "${NCCL_SOURCE_DIR}"
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "${NCCL_BUILD_COMMAND}" BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}" INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
INSTALL_DIR "${NCCL_INSTALL_DIR}" INSTALL_DIR "${NCCL_INSTALL_DIR}"
TEST_COMMAND "" TEST_COMMAND ""
) )
if (WITH_DSO) if(WITH_DSO)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0") if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile}) add_library(nccl STATIC ${dummyfile})
else() else()
add_library(nccl INTERFACE) add_library(nccl INTERFACE)
endif() endif()
else() else()
ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL) add_library(nccl STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION set_property(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl.a) ${NCCL_INSTALL_DIR}/lib/libnccl_static.a)
endif() endif()
add_dependencies(nccl extern_nccl) add_dependencies(nccl extern_nccl)
LIST(APPEND external_project_dependencies nccl)
...@@ -2,21 +2,21 @@ ...@@ -2,21 +2,21 @@
## Motivation ## Motivation
The model is the output of training process. One complete model consists of two parts, namely, the **topology** and the **parameters**. To support industrial deployment, we need to make the model format must be self-completed and do not expose any training source code. A model is an output of the training process. One complete model consists of two parts, the **topology** and the **parameters**. In order to support industrial deployment, the model format must be self-complete and must not expose any training source code.
As a result, In PaddlePaddle, the **topology** represents as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model, we must support large size parameter, and efficient serialization/deserialization. As a result, In PaddlePaddle, the **topology** is represented as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model. We must support large size parameters and efficient serialization/deserialization of parameters.
## Implementation ## Implementation
The topology is saved as a plain text, in detail, a self-contain protobuf file. The topology is saved as a plain text in a detailed self-contain protobuf file.
The parameters are saved as a binary file. As we all know, the protobuf message has the limits of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We do a (benchmark experiment)[https://github.com/PaddlePaddle/Paddle/pull/4610], its result shows protobuf is not fit in this scene. The parameters are saved as a binary file. As we all know, the protobuf message has a limit of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We have done a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), which shows that protobuf is not fit for the task.
As a result, we design a particular format for tensor serialization. By default, arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of (LoDTensorDesc)[https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99]. We save the DescProto as the byte string header, it contains the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). Tensor stores value in a continuous memory buffer, for speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is, As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**| |HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**|
In detail, tensor's byte view as the table shows. Note that all the signed value written in little-endian. The table below shows a tensor's byte view in detail. Note that all the signed values are written in the little-endian format.
```text ```text
[offset] [type] [description] [offset] [type] [description]
...@@ -33,4 +33,6 @@ In detail, tensor's byte view as the table shows. Note that all the signed valu ...@@ -33,4 +33,6 @@ In detail, tensor's byte view as the table shows. Note that all the signed valu
## Summary ## Summary
We introduce the model format, the `ProgramDesc` describe the **topology**, and a bunch of particular format binary tensors describes the **parameters**. - We introduce a model format.
- The `ProgramDesc` describe the model **topology**.
- A bunch of specified format binary tensors describe the **parameters**.
...@@ -25,9 +25,8 @@ import ( ...@@ -25,9 +25,8 @@ import (
"strings" "strings"
"time" "time"
log "github.com/inconshreveable/log15"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
...@@ -41,16 +40,20 @@ func main() { ...@@ -41,16 +40,20 @@ func main() {
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warn, error, crit")
flag.Parse() flag.Parse()
level, e := log.ParseLevel(*logLevel) lvl, err := log.LvlFromString(*logLevel)
candy.Must(e) if err != nil {
panic(err)
}
log.SetLevel(level) log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
if *endpoints == "" { if *endpoints == "" {
log.Warningln("-endpoints not set, fault tolerance not be enabled.") log.Warn("-endpoints not set, fault tolerance not be enabled.")
} }
var store master.Store var store master.Store
...@@ -58,23 +61,25 @@ func main() { ...@@ -58,23 +61,25 @@ func main() {
eps := strings.Split(*endpoints, ",") eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP() ip, err := networkhelper.GetExternalIP()
if err != nil { if err != nil {
log.Fatal(err) log.Crit("get external ip error", log.Ctx{"error": err})
panic(err)
} }
addr := fmt.Sprintf("%s:%d", ip, *port) addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error creating etcd client.", log.Ctx{"error": err})
panic(err)
} }
} else { } else {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() { shutdown := func() {
log.Infoln("shutting down gracefully") log.Info("shutting down gracefully")
err := store.Shutdown() err := store.Shutdown()
if err != nil { if err != nil {
log.Errorln(err) log.Error("shutdown error", log.Ctx{"error": err})
} }
} }
...@@ -86,24 +91,28 @@ func main() { ...@@ -86,24 +91,28 @@ func main() {
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error creating new service.", log.Ctx{"error": err})
panic(err)
} }
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error registering to etcd.", log.Ctx{"error": err})
panic(err)
} }
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
panic(err)
} }
go func() { go func() {
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error serving HTTP", log.Ctx{"error": err})
panic(err)
} }
}() }()
......
...@@ -27,11 +27,11 @@ import ( ...@@ -27,11 +27,11 @@ import (
"github.com/topicai/candy" "github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 8001, "port of the pserver")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry") index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
...@@ -41,13 +41,17 @@ func main() { ...@@ -41,13 +41,17 @@ func main() {
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warn, error, crit")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) lvl, err := log.LvlFromString(*logLevel)
candy.Must(err) if err != nil {
panic(err)
}
log.SetLevel(level) log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
var idx int var idx int
...@@ -63,7 +67,7 @@ func main() { ...@@ -63,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx) cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil { if err != nil {
if err == pserver.ErrCheckpointNotFound { if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.") log.Info("Could not find the pserver checkpoint.")
} else { } else {
panic(err) panic(err)
} }
...@@ -71,10 +75,10 @@ func main() { ...@@ -71,10 +75,10 @@ func main() {
} }
shutdown := func() { shutdown := func() {
log.Infoln("shutting down gracefully") log.Info("shutting down gracefully")
sErr := e.Shutdown() sErr := e.Shutdown()
if sErr != nil { if sErr != nil {
log.Errorln(sErr) log.Error("error shutting down", log.Ctx{"error": sErr})
} }
} }
...@@ -95,7 +99,7 @@ func main() { ...@@ -95,7 +99,7 @@ func main() {
candy.Must(err) candy.Must(err)
go func() { go func() {
log.Infof("start pserver at port %d", *port) log.Info("starting pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err) candy.Must(err)
}() }()
......
hash: 328e7b9b7306b45e7b9879139a9f86698115981f6283032e1312093a6a6ddb04 hash: 51d9e2e46d7fd9173ff11ecada40f7b7728756be18d5e2f032535f66465e6e15
updated: 2017-10-16T08:00:23.484693528Z updated: 2017-10-24T15:04:09.987751592-07:00
imports: imports:
- name: github.com/alecthomas/gometalinter - name: github.com/alecthomas/gometalinter
version: bae2f1293d092fd8167939d5108d1b025eaef9de version: bae2f1293d092fd8167939d5108d1b025eaef9de
...@@ -99,6 +99,8 @@ imports: ...@@ -99,6 +99,8 @@ imports:
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml - name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7 version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/go-stack/stack
version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
- name: github.com/gogo/protobuf - name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages: subpackages:
...@@ -120,8 +122,14 @@ imports: ...@@ -120,8 +122,14 @@ imports:
- runtime - runtime
- runtime/internal - runtime/internal
- utilities - utilities
- name: github.com/inconshreveable/log15
version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
- name: github.com/jonboulle/clockwork - name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09 version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/mattn/go-colorable
version: 5411d3eea5978e6cdc258b30de592b60df6aba96
- name: github.com/mattn/go-isatty
version: 57fdcb988a5c543893cc61bce354a6e24ab70022
- name: github.com/matttproud/golang_protobuf_extensions - name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages: subpackages:
...@@ -179,11 +187,12 @@ imports: ...@@ -179,11 +187,12 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: 0f826bdd13b500be0f1d4004938ad978fcc6031e version: e48874b42435b4347fc52bdee0424a52abc974d7
repo: https://github.com/golang/sys.git repo: https://github.com/golang/sys.git
vcs: git vcs: git
subpackages: subpackages:
- unix - unix
- windows
- name: golang.org/x/text - name: golang.org/x/text
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git repo: https://github.com/golang/text.git
...@@ -222,4 +231,3 @@ testImports: ...@@ -222,4 +231,3 @@ testImports:
version: 05e8a0eda380579888eb53c394909df027f06991 version: 05e8a0eda380579888eb53c394909df027f06991
subpackages: subpackages:
- assert - assert
...@@ -26,3 +26,7 @@ import: ...@@ -26,3 +26,7 @@ import:
version: v1.1.0 version: v1.1.0
- package: github.com/alecthomas/gometalinter - package: github.com/alecthomas/gometalinter
version: v1.2.1 version: v1.2.1
- package: github.com/inconshreveable/log15
version: v2.13
- package: github.com/go-stack/stack
version: v1.6.0
...@@ -35,13 +35,19 @@ import ( ...@@ -35,13 +35,19 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client) var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client var curHandle C.paddle_master_client
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
func add(c *master.Client) C.paddle_master_client { func add(c *master.Client) C.paddle_master_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
...@@ -117,7 +123,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -117,7 +123,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
} }
err := c.SetDataset(paths) err := c.SetDataset(paths)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error set dataset", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }
...@@ -167,7 +173,7 @@ func paddle_request_save_model(client C.paddle_master_client, trainerID string, ...@@ -167,7 +173,7 @@ func paddle_request_save_model(client C.paddle_master_client, trainerID string,
c := get(client) c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond) need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error request save model", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }
......
...@@ -21,7 +21,7 @@ import ( ...@@ -21,7 +21,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// Client is the client of the master server. // Client is the client of the master server.
...@@ -75,7 +75,7 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { ...@@ -75,7 +75,7 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
for { for {
err := f() err := f()
if err != nil { if err != nil {
log.Warningln(err) log.Warn("create etcd client error", log.Ctx{"error": err})
} else { } else {
break break
} }
...@@ -135,13 +135,13 @@ func (c *Client) getRecords(passID int) { ...@@ -135,13 +135,13 @@ func (c *Client) getRecords(passID int) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
log.Errorf("getTask error: %s", err) log.Error("getTask error.", log.Ctx{"error": err})
} }
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, e := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if e != nil { if e != nil {
log.Errorln(e) log.Error("error open chunk", log.Ctx{"error": e})
continue continue
} }
...@@ -152,12 +152,15 @@ func (c *Client) getRecords(passID int) { ...@@ -152,12 +152,15 @@ func (c *Client) getRecords(passID int) {
if s.Err() != nil { if s.Err() != nil {
c.ch <- record{nil, s.Err()} c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path) log.Error(
"error scan chunk",
log.Ctx{"error": err, "path": chunk.Path},
)
} }
err = f.Close() err = f.Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error close record file", log.Ctx{"error": err})
} }
} }
...@@ -166,7 +169,7 @@ func (c *Client) getRecords(passID int) { ...@@ -166,7 +169,7 @@ func (c *Client) getRecords(passID int) {
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
err = c.taskFinished(t.Meta.ID) err = c.taskFinished(t.Meta.ID)
if err != nil { if err != nil {
log.Errorln(err) log.Error("task finish callback error.", log.Ctx{"error": err})
} }
} }
} }
...@@ -179,12 +182,12 @@ func (c *Client) monitorMaster(addrCh <-chan string) { ...@@ -179,12 +182,12 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("close old master addr error", log.Ctx{"error": err})
} }
} else { } else {
err := c.conn.Connect(curMaster) err := c.conn.Connect(curMaster)
if err != nil { if err != nil {
log.Errorln(err) log.Error("connect to new master addr error", log.Ctx{"error": err})
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
......
...@@ -25,8 +25,6 @@ import ( ...@@ -25,8 +25,6 @@ import (
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
...@@ -36,10 +34,6 @@ const ( ...@@ -36,10 +34,6 @@ const (
chunkPerTask = 10 chunkPerTask = 10
) )
func init() {
log.SetLevel(log.ErrorLevel)
}
func TestGetFinishTask(t *testing.T) { func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0" const path = "/tmp/master_client_test_0"
......
...@@ -20,7 +20,7 @@ import ( ...@@ -20,7 +20,7 @@ import (
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -44,7 +44,7 @@ type EtcdClient struct { ...@@ -44,7 +44,7 @@ type EtcdClient struct {
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
...@@ -64,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -64,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Infof("Trying to acquire lock at %s.", lockPath) log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("Successfully acquired lock at %s.", lockPath) log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
put := clientv3.OpPut(addrPath, addr) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
...@@ -78,7 +78,8 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -78,7 +78,8 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Fatal("No longer owns the master lock. Exiting.") log.Crit("No longer owns the master lock. Exiting.")
panic("No longer owns the master lock. Exiting.")
} }
e := &EtcdClient{ e := &EtcdClient{
...@@ -102,7 +103,7 @@ func (e *EtcdClient) Save(state []byte) error { ...@@ -102,7 +103,7 @@ func (e *EtcdClient) Save(state []byte) error {
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock again") log.Error("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx) err := e.lock.Lock(ctx)
cancel() cancel()
...@@ -116,9 +117,10 @@ func (e *EtcdClient) Save(state []byte) error { ...@@ -116,9 +117,10 @@ func (e *EtcdClient) Save(state []byte) error {
// to kill current master server. The current // to kill current master server. The current
// state is not saved, but the trainer's RPC // state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry. // call will fail, so the trainer will retry.
log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
panic("Could not acquire the lock at %s: %v. Exiting.")
} }
log.Infof("Successfully acquired lock at %s.", e.lockPath) log.Info("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state) return e.Save(state)
} }
...@@ -136,7 +138,7 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -136,7 +138,7 @@ func (e *EtcdClient) Load() ([]byte, error) {
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock and load again.") log.Error("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background()) err = e.lock.Lock(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -163,7 +165,7 @@ func (e *EtcdClient) Shutdown() error { ...@@ -163,7 +165,7 @@ func (e *EtcdClient) Shutdown() error {
if err == nil { if err == nil {
err = newErr err = newErr
} else { } else {
log.Errorln(newErr) log.Error("shutdown error", log.Ctx{"error": newErr})
} }
} }
...@@ -192,7 +194,7 @@ func watchKey(c *clientv3.Client, key string, valChan chan<- string) { ...@@ -192,7 +194,7 @@ func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
for wresp := range rch { for wresp := range rch {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string // if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value) log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
valChan <- string(ev.Kv.Value) valChan <- string(ev.Kv.Value)
} }
} }
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
...@@ -170,11 +170,11 @@ func (s *Service) recover() (bool, error) { ...@@ -170,11 +170,11 @@ func (s *Service) recover() (bool, error) {
} }
if state == nil { if state == nil {
log.Infoln("No state exists, not recovered.") log.Info("No state exists, not recovered.")
return false, nil return false, nil
} }
log.Infof("Loaded snapshot of size: %d bytes.", len(state)) log.Info("Loaded snapshot.", log.Ctx{"size": len(state)})
gr, err := gzip.NewReader(bytes.NewReader(state)) gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil { if err != nil {
return false, err return false, err
...@@ -191,11 +191,11 @@ func (s *Service) recover() (bool, error) { ...@@ -191,11 +191,11 @@ func (s *Service) recover() (bool, error) {
if err != nil { if err != nil {
// Only close failed, recover actually succeed, so // Only close failed, recover actually succeed, so
// just log error. // just log error.
log.Errorln(err) log.Error("error close recover file.", log.Ctx{"error": err})
} }
s.state = tqs s.state = tqs
log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.") log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx())
for _, t := range s.state.Pending { for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
} }
...@@ -224,7 +224,7 @@ func (s *Service) snapshot() error { ...@@ -224,7 +224,7 @@ func (s *Service) snapshot() error {
} }
state := buf.Bytes() state := buf.Bytes()
log.Infof("Saving snapshot of size: %d bytes.", len(state)) log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)})
return s.store.Save(state) return s.store.Save(state)
} }
...@@ -260,7 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -260,7 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
count := index.NumChunks() count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count) log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count})
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := Chunk{ chunk := Chunk{
Path: path, Path: path,
...@@ -300,7 +300,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error { ...@@ -300,7 +300,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
return err return err
} }
close(s.ready) close(s.ready)
...@@ -320,7 +320,7 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -320,7 +320,7 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
defer func() { defer func() {
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
} }
}() }()
...@@ -328,12 +328,12 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -328,12 +328,12 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
t.NumFailure++ t.NumFailure++
if t.NumFailure > s.failureMax { if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Failed = append(s.state.Failed, t) s.state.Failed = append(s.state.Failed, t)
return return
} }
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure) log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Todo = append(s.state.Todo, t) s.state.Todo = append(s.state.Todo, t)
return return
} }
...@@ -353,8 +353,8 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -353,8 +353,8 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
} }
// must be called with lock held. // must be called with lock held.
func (s *Service) logFields() log.Fields { func (s *Service) logCtx() log.Ctx {
return log.Fields{ return log.Ctx{
"todoLen": len(s.state.Todo), "todoLen": len(s.state.Todo),
"pendingLen": len(s.state.Pending), "pendingLen": len(s.state.Pending),
"doneLen": len(s.state.Done), "doneLen": len(s.state.Done),
...@@ -383,10 +383,10 @@ func (s *Service) GetTask(passID int, task *Task) error { ...@@ -383,10 +383,10 @@ func (s *Service) GetTask(passID int, task *Task) error {
if len(s.state.Todo) == 0 { if len(s.state.Todo) == 0 {
if len(s.state.Done) == 0 && len(s.state.Pending) == 0 { if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass") log.Warn("All tasks failed, may start next pass", s.logCtx())
return ErrAllTaskFailed return ErrAllTaskFailed
} }
log.WithFields(s.logFields()).Warningln("No more available task.") log.Warn("No more available task.", s.logCtx())
return ErrNoMoreAvailable return ErrNoMoreAvailable
} }
...@@ -400,8 +400,9 @@ func (s *Service) GetTask(passID int, task *Task) error { ...@@ -400,8 +400,9 @@ func (s *Service) GetTask(passID int, task *Task) error {
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta) ctx := s.logCtx()
ctx["task meta"] = t.Task.Meta
log.Info("Task dispatched.", ctx)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
...@@ -417,7 +418,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -417,7 +418,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.state.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) ctx := s.logCtx()
ctx["task id"] = taskID
log.Warn("Pending task not found.", ctx)
return nil return nil
} }
...@@ -426,7 +429,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -426,7 +429,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = append(s.state.Done, t) s.state.Done = append(s.state.Done, t)
delete(s.state.Pending, taskID) delete(s.state.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) ctx := s.logCtx()
ctx["task id"] = taskID
log.Info("Task finished.", ctx)
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 { if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished // increase master side pass count if all tasks finished
s.state.CurPass++ s.state.CurPass++
...@@ -434,12 +439,14 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -434,12 +439,14 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = []taskEntry{} s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks // TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{} s.state.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass) ctx := s.logCtx()
ctx["new pass"] = s.state.CurPass
log.Warn("all task finished, add new pass data.", ctx)
} }
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
} }
return err return err
} }
...@@ -455,7 +462,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { ...@@ -455,7 +462,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
t, ok := s.state.Pending[meta.ID] t, ok := s.state.Pending[meta.ID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta) log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta})
return nil return nil
} }
......
...@@ -45,9 +45,15 @@ import ( ...@@ -45,9 +45,15 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client" "github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
...@@ -164,10 +170,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, ...@@ -164,10 +170,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name) log.Warn(
"parameter already initialized, treat paddle_init_param as successful.",
log.Ctx{"parameter": name},
)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Error("error init param", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -180,11 +189,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { ...@@ -180,11 +189,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.") log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Error("error finish init params", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -205,7 +214,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient ...@@ -205,7 +214,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient
c := get(client) c := get(client)
err := c.SendGrads(gs) err := c.SendGrads(gs)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error send grads", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -222,7 +231,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -222,7 +231,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
c := get(client) c := get(client)
ps, err := c.GetParams(ns) ps, err := c.GetParams(ns)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error get params", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -231,7 +240,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -231,7 +240,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Error(
"pserver returned wrong number of parameters.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -241,7 +256,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -241,7 +256,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Error(
"pserver returned wrong parameters, or not in requested order.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
...@@ -251,13 +272,19 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -251,13 +272,19 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nil { if unsafe.Pointer(param) == nil {
log.Errorln("must pre-allocate parameter.") log.Error("must pre-allocate parameter.")
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
if unsafe.Pointer(param.content) != nil { if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) { if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) log.Error(
"the pre-allocated content len does not match parameter content len.",
log.Ctx{
"Pre-allocated len": param.content_len,
"Returned len": len(p.Content),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
...@@ -84,7 +84,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -84,7 +84,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
if curServers[i].Addr == "" { if curServers[i].Addr == "" {
err := c.pservers[i].Close() err := c.pservers[i].Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error closing connection to pserver", log.Ctx{"error": err})
} }
continue continue
...@@ -92,7 +92,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -92,7 +92,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
err := c.pservers[i].Connect(curServers[i].Addr) err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error connecting to pserver", log.Ctx{"error": err})
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
......
...@@ -30,7 +30,7 @@ import ( ...@@ -30,7 +30,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client" "github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -90,7 +90,7 @@ func initEtcdClient() { ...@@ -90,7 +90,7 @@ func initEtcdClient() {
DialTimeout: time.Second * time.Duration(1), DialTimeout: time.Second * time.Duration(1),
}) })
if err != nil { if err != nil {
log.Errorf("err %v", err) log.Error("error init etcd client", log.Ctx{"error": err})
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err = client.Delete(ctx, pserver.PsDesired) _, err = client.Delete(ctx, pserver.PsDesired)
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -54,26 +54,29 @@ func (e *Etcd) Desired() int { ...@@ -54,26 +54,29 @@ func (e *Etcd) Desired() int {
resp, err := e.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Error(
"Get ps dresire number failed! reconnecting...",
log.Ctx{"error": err},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Info("Waiting for ps desired registered ...")
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err) log.Error("atoi failed", log.Ctx{"error": err})
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("Get psDesired number: %d", psDesired) log.Debug("Got psDesired", log.Ctx{"psDesired": psDesired})
break break
} }
return psDesired return psDesired
...@@ -88,17 +91,20 @@ func (e *Etcd) List() []Server { ...@@ -88,17 +91,20 @@ func (e *Etcd) List() []Server {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debug("looking for pserver", log.Ctx{"ps key": psKey})
resp, err := e.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel() cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Info(
"Get psKey error",
log.Ctx{"ps key": psKey, "error": err},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Info("Waiting for ps addr registered ...")
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -106,11 +112,17 @@ func (e *Etcd) List() []Server { ...@@ -106,11 +112,17 @@ func (e *Etcd) List() []Server {
psAddr := string(resp.Kvs[0].Value) psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Info(
"Value under psKey is empty",
log.Ctx{"psKey": psKey},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("got value (%s) for key: %s", psAddr, psKey) log.Debug(
"got psAddr given psKey",
log.Ctx{"psAddr": psAddr, "psKey": psKey},
)
servers[i].Index = i servers[i].Index = i
servers[i].Addr = psAddr servers[i].Addr = psAddr
} }
...@@ -130,13 +142,13 @@ func NewEtcd(endpoints string) *Etcd { ...@@ -130,13 +142,13 @@ func NewEtcd(endpoints string) *Etcd {
DialTimeout: defaultEtcdTimeout, DialTimeout: defaultEtcdTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("Init etcd connection failed: %v", err) log.Error("Init etcd connection failed", log.Ctx{"error": err})
time.Sleep(defaultEtcdTimeout) time.Sleep(defaultEtcdTimeout)
continue continue
} }
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Info("Connected to etcd endpoint", log.Ctx{"endpoint": endpoints})
client := &Etcd{ client := &Etcd{
client: cli, client: cli,
timeout: defaultEtcdTimeout, timeout: defaultEtcdTimeout,
...@@ -154,7 +166,7 @@ func (e *Etcd) Select() (bool, error) { ...@@ -154,7 +166,7 @@ func (e *Etcd) Select() (bool, error) {
} }
lock := concurrency.NewMutex(sess, initLockPath) lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath) log.Info("Trying to acquire lock", log.Ctx{"lock path": initLockPath})
// Do not use timeout context here, since we don't know how // Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the // long does it take for other trainers to initialize the
// parameters. // parameters.
...@@ -162,7 +174,7 @@ func (e *Etcd) Select() (bool, error) { ...@@ -162,7 +174,7 @@ func (e *Etcd) Select() (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
log.Infof("Successfully acquired lock at %s.", initLockPath) log.Info("Successfully acquired lock", log.Ctx{"lock path": initLockPath})
get := clientv3.OpGet(initDonePath) get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
...@@ -181,17 +193,17 @@ func (e *Etcd) Select() (bool, error) { ...@@ -181,17 +193,17 @@ func (e *Etcd) Select() (bool, error) {
if len(resp.Kvs) == 0 { if len(resp.Kvs) == 0 {
// Key value not set, select current trainer. // Key value not set, select current trainer.
e.lock = lock e.lock = lock
log.Infoln("Trainer selected.") log.Info("Trainer selected.")
return true, nil return true, nil
} }
if string(resp.Kvs[0].Value) == initDoneVal { if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.") log.Info("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout) ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx) err = lock.Unlock(ctx)
cancel() cancel()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error unlocking", log.Ctx{"error": err})
} }
return false, nil return false, nil
} }
...@@ -221,7 +233,7 @@ func (e *Etcd) Done() error { ...@@ -221,7 +233,7 @@ func (e *Etcd) Done() error {
err = e.lock.Unlock(ctx) err = e.lock.Unlock(ctx)
cancel() cancel()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error unlocking", log.Ctx{"error": err})
} else { } else {
e.lock = nil e.lock = nil
} }
...@@ -244,7 +256,7 @@ func (e *Etcd) Close() error { ...@@ -244,7 +256,7 @@ func (e *Etcd) Close() error {
cErr := e.client.Close() cErr := e.client.Close()
if cErr != nil { if cErr != nil {
if err != nil { if err != nil {
log.Errorln(cErr) log.Error("error closing etcd client", log.Ctx{"error": cErr})
return err return err
} }
return cErr return cErr
......
...@@ -24,7 +24,7 @@ import ( ...@@ -24,7 +24,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -82,19 +82,19 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -82,19 +82,19 @@ func (e *EtcdClient) Register(port int) (int, error) {
DialTimeout: e.dialTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Error("connect to etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.client = cli e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec)) sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil { if err != nil {
log.Errorf("create etcd session error: %v", err) log.Error("create etcd session error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.sess = sess e.sess = sess
log.Debugf("inited client to %s", e.endpoints) log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints})
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
...@@ -104,7 +104,7 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -104,7 +104,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
_, err := e.initDesiredPservers(ctx, e.numPservers) _, err := e.initDesiredPservers(ctx, e.numPservers)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
...@@ -119,14 +119,17 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -119,14 +119,17 @@ func (e *EtcdClient) Register(port int) (int, error) {
resp, err := e.client.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Error(
"psDesired atoi error",
log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)},
)
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
...@@ -143,7 +146,7 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -143,7 +146,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
pserverIdx, err = e.registerPserverEtcd(ctx, port) pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn("register pserver on etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
...@@ -170,16 +173,17 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er ...@@ -170,16 +173,17 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey) ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debug(
"register pserver got value",
log.Ctx{"value": ps, "key": psKey},
)
if ps == "" { if ps == "" {
// find the first id and write info // find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease())) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr})
log.Debug("register finished")
idx = i idx = i
registered = true registered = true
break break
...@@ -239,7 +243,7 @@ func (e *EtcdClient) Shutdown() error { ...@@ -239,7 +243,7 @@ func (e *EtcdClient) Shutdown() error {
newErr := e.client.Close() newErr := e.client.Close()
if newErr != nil { if newErr != nil {
if err != nil { if err != nil {
log.Errorln(newErr) log.Error("shutdown error", log.Ctx{"error": newErr})
} else { } else {
err = newErr err = newErr
} }
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"fmt" "fmt"
"unsafe" "unsafe"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
type optimizer struct { type optimizer struct {
...@@ -56,12 +56,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer ...@@ -56,12 +56,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State s := State
paramBufferSize := C.size_t(len(p.Content)) paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{ log.Info("New Optimizer Created with config", log.Ctx{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": paramBufferSize, "ParamSize": paramBufferSize,
"ConfigSize": len(c), "ConfigSize": len(c),
"StateSize": len(s), "StateSize": len(s),
}).Info("New Optimizer Created with config:") })
var cbuffer unsafe.Pointer var cbuffer unsafe.Pointer
cbuffer = C.malloc(paramBufferSize) cbuffer = C.malloc(paramBufferSize)
......
...@@ -32,7 +32,7 @@ import ( ...@@ -32,7 +32,7 @@ import (
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -124,6 +124,9 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { ...@@ -124,6 +124,9 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
// LoadCheckpoint loads checkpoint from file. // LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")
cpMeta, err := loadMeta(e, idx) cpMeta, err := loadMeta(e, idx)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -178,6 +181,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient ...@@ -178,6 +181,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
log.Warn("init param called but parameters already initialized.")
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -191,6 +195,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error ...@@ -191,6 +195,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
log.Info(
"init parameter",
"name", paramWithConfigs.Param.Name,
"config len", len(paramWithConfigs.Config),
"param len", len(paramWithConfigs.Param.Content),
"type", paramWithConfigs.Param.ElementType,
)
return nil return nil
} }
...@@ -199,6 +210,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error ...@@ -199,6 +210,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
func (s *Service) FinishInitParams(_ int, _ *int) error { func (s *Service) FinishInitParams(_ int, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
log.Warn("finished init param called but parameters already initialized.")
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -209,10 +221,12 @@ func (s *Service) FinishInitParams(_ int, _ *int) error { ...@@ -209,10 +221,12 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t { for range t {
err := s.checkpoint() err := s.checkpoint()
if err != nil { if err != nil {
log.Errorln(err) log.Error("finish init params error", log.Ctx{"error": err})
} }
} }
}() }()
log.Info("init parameter finished.")
return nil return nil
} }
...@@ -222,6 +236,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error { ...@@ -222,6 +236,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
log.Warn("received gradient before initialization.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
return errors.New(Uninitialized) return errors.New(Uninitialized)
} }
...@@ -233,6 +248,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error { ...@@ -233,6 +248,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
return fmt.Errorf("parameter: %s does not exist", g.Name) return fmt.Errorf("parameter: %s does not exist", g.Name)
} }
log.Info("received gradient from trainer, updating gradient.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
return o.UpdateParameter(g) return o.UpdateParameter(g)
} }
...@@ -244,6 +260,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -244,6 +260,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
opt, ok := s.optMap[name] opt, ok := s.optMap[name]
if !ok { if !ok {
log.Warn("trainer wants to get a parameter that does not exist.", "name", name)
return fmt.Errorf("parameter: %s does not exist", name) return fmt.Errorf("parameter: %s does not exist", name)
} }
...@@ -257,12 +274,13 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -257,12 +274,13 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name parameter.Name = name
parameter.ElementType = opt.elementType parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights() parameter.Content = opt.GetWeights()
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil return nil
} }
func traceTime(start time.Time, name string) { func traceTime(start time.Time, name string) {
elapsed := time.Since(start) elapsed := time.Since(start)
log.Infof("%s took %v", name, elapsed) log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
} }
// checkpoint saves checkpoint to disk. // checkpoint saves checkpoint to disk.
...@@ -270,7 +288,7 @@ func traceTime(start time.Time, name string) { ...@@ -270,7 +288,7 @@ func traceTime(start time.Time, name string) {
// checkpoint should be only called after the parameters are // checkpoint should be only called after the parameters are
// initialized. // initialized.
func (s *Service) checkpoint() (err error) { func (s *Service) checkpoint() (err error) {
log.Infoln("Begin save checkpoint.") log.Info("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint") defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock() s.mu.Lock()
...@@ -315,7 +333,7 @@ func (s *Service) checkpoint() (err error) { ...@@ -315,7 +333,7 @@ func (s *Service) checkpoint() (err error) {
closeErr := f.Close() closeErr := f.Close()
if closeErr != nil { if closeErr != nil {
if err != nil { if err != nil {
log.Errorln(closeErr) log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
} else { } else {
// Set closeErr as return value. // Set closeErr as return value.
err = closeErr err = closeErr
...@@ -336,7 +354,7 @@ func (s *Service) checkpoint() (err error) { ...@@ -336,7 +354,7 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx) oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound { if err == ErrCheckpointNotFound {
log.Infoln("Do not have existing checkpoint.") log.Info("Do not have existing checkpoint.")
err = nil err = nil
} }
...@@ -368,7 +386,7 @@ func (s *Service) checkpoint() (err error) { ...@@ -368,7 +386,7 @@ func (s *Service) checkpoint() (err error) {
if rmErr != nil { if rmErr != nil {
// log error, but still treat checkpoint as // log error, but still treat checkpoint as
// successful. // successful.
log.Errorln(rmErr) log.Error("remove old meta file error", log.Ctx{"error": rmErr})
} }
} }
......
...@@ -106,6 +106,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const { ...@@ -106,6 +106,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const {
return lod_[level][idx + 1] - lod_[level][idx]; return lod_[level][idx + 1] - lod_[level][idx];
} }
size_t LoDTensor::NumInstancesInElement(size_t level, size_t idx) const {
PADDLE_ENFORCE_LT(level, NumLevels());
PADDLE_ENFORCE_LT(idx, NumElements(level));
auto abs_lod = ToAbsOffset(lod());
size_t begin = abs_lod[level][idx];
size_t end = abs_lod[level][idx + 1];
return end - begin;
}
void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) { void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end); auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod; lod_ = new_lod;
...@@ -117,8 +126,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, ...@@ -117,8 +126,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
PADDLE_ENFORCE_LT(elem_begin, NumElements(level)); PADDLE_ENFORCE_LT(elem_begin, NumElements(level));
PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1); PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1);
auto abs_lod = framework::ToAbsOffset(lod());
auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end); auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end);
lod_ = new_lod; lod_ = new_lod;
// slice the underlying tensor
size_t begin = abs_lod[level][elem_begin];
size_t end = abs_lod[level][elem_end];
PADDLE_ENFORCE_LT(begin, end, "Cannot shrink, the result tensor is empty.");
ShareDataWith(Slice(begin, end));
} }
std::string LoDTensor::SerializeToString() const { std::string LoDTensor::SerializeToString() const {
......
...@@ -122,6 +122,12 @@ class LoDTensor : public Tensor { ...@@ -122,6 +122,12 @@ class LoDTensor : public Tensor {
*/ */
size_t NumElements(size_t level, size_t idx) const; size_t NumElements(size_t level, size_t idx) const;
/*
* Get the number of instances in the underlying tensor in the `idx`-th
* element.
*/
size_t NumInstancesInElement(size_t level, size_t idx) const;
/* /*
* Shrink levels[level_begin:level_end] * Shrink levels[level_begin:level_end]
*/ */
...@@ -157,5 +163,42 @@ class LoDTensor : public Tensor { ...@@ -157,5 +163,42 @@ class LoDTensor : public Tensor {
private: private:
LoD lod_; LoD lod_;
}; };
/*
* Expand the `source` to fit the LoD of `lod`. For example, a `source`
* LoDTensor is
* - LoD: [0, 2]
* - tensor: [a0, a1]
* a `lod` is
* - LoD: [0 3 5]
* returns a new LoDTensor
* - [a0 a0 a0 a1 a1]
*/
template <typename T>
LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
const platform::Place& place) {
LoD abs_lod = ToAbsOffset(lod);
const auto& lod_level = lod[level];
size_t num_instances = source.dims()[0];
// new tensor
LoDTensor tensor;
tensor.set_lod(lod);
auto dims = source.dims();
dims[0] = lod_level.back();
tensor.Resize(dims);
tensor.mutable_data<T>(place);
PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
tensor.Slice(elem, elem + 1)
.CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
return tensor;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { ...@@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
size_t level = 0; size_t level = 0;
LoDTensor new_lod_tensor = lod_tensor_; LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.ShrinkInLevel(level, 0, 1); new_lod_tensor.ShrinkInLevel(level, 0, 1);
EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
EXPECT_EQ(new_lod_tensor.NumElements(0), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 5UL); ASSERT_EQ(new_lod_tensor.NumElements(2), 5UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>()); ASSERT_EQ(new_lod_tensor.dims()[0], 12);
for (int i = 0; i < 12 * 128; i++) {
ASSERT_EQ(new_lod_tensor.data<float>()[i], i);
}
level = 1; level = 1;
new_lod_tensor = lod_tensor_; new_lod_tensor = lod_tensor_;
...@@ -104,7 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { ...@@ -104,7 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>()); ASSERT_EQ(new_lod_tensor.dims()[0], 7);
for (int i = 5 * 128; i < 12 * 128; i++) {
ASSERT_EQ(new_lod_tensor.data<float>()[i - 5 * 128], i);
}
LoDTensor t1;
t1.set_lod(lod_tensor_.lod());
t1.ShareDataWith(lod_tensor_);
LoDTensor t2;
t2.set_lod(lod_tensor_.lod());
t2.ShareDataWith(lod_tensor_);
t1.ShrinkInLevel(0, 1, 2);
t2.ShrinkInLevel(0, 0, 1);
EXPECT_NE(t1.data<float>(), t2.data<float>());
EXPECT_NE(t1.data<float>(), lod_tensor_.data<float>());
}
TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
tensor.set_lod(lod);
tensor.Resize({2, 1});
tensor.mutable_data<float>(platform::CPUPlace());
tensor.data<float>()[0] = 0;
tensor.data<float>()[1] = 1;
LoD target;
target.emplace_back(std::vector<size_t>{0, 3, 5});
auto new_tensor = LodExpand<float>(tensor, target, 0UL, platform::CPUPlace());
std::vector<int> result{{0, 0, 0, 1, 1}};
for (size_t i = 0; i < 5; i++) {
ASSERT_EQ(new_tensor.data<float>()[i], result[i]);
}
} }
TEST_F(LoDTensorTester, SerializeDeserialize) { TEST_F(LoDTensorTester, SerializeDeserialize) {
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include "paddle/framework/eigen.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) { ...@@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) {
values_.resize(index + 1); values_.resize(index + 1);
} }
values_[index].set_lod(value.lod());
values_[index].Resize(value.dims()); values_[index].Resize(value.dims());
values_[index].mutable_data<value_type>(platform::CPUPlace()); values_[index].mutable_data<value_type>(value.place());
values_[index].CopyFrom(value, platform::CPUPlace(), values_[index].CopyFrom(value, value.place(), platform::CPUDeviceContext());
platform::CPUDeviceContext());
} }
void TensorArray::WriteShared(size_t index, const LoDTensor& value) { void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
...@@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) { ...@@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
values_.resize(index + 1); values_.resize(index + 1);
} }
values_[index].set_lod(value.lod());
values_[index].ShareDataWith(value); values_[index].ShareDataWith(value);
} }
...@@ -144,6 +147,156 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level, ...@@ -144,6 +147,156 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
return unpacker.meta; return unpacker.meta;
} }
LoDTensor TensorArray::LodPack(size_t level) const {
PADDLE_ENFORCE_GT(size(), 0UL, "no time step exists");
// the levels should be no less than 2
LoDTensor merged;
const LoDTensor *pre, *cur;
pre = &Read(0);
for (size_t step = 1; step < size(); step++) {
cur = &Read(step);
PADDLE_ENFORCE_GT(cur->NumLevels(), 0);
PADDLE_ENFORCE_GT(pre->NumLevels(), 0);
PADDLE_ENFORCE_EQ(pre->NumLevels(), cur->NumLevels());
PADDLE_ENFORCE_EQ(pre->NumElements(level), cur->NumElements(level));
merged = LodPackTwo(*pre, *cur, level);
pre = &merged;
}
return merged;
}
/*
* NOTE currently, only the lowest level supports packing.
* The lowest LoD will be changed, while the relative offsets in levels above
* stay unchanged.
*
* previous step : [0] [1] [3]
* current step: [0 1 2] [2 3] []
* packed to
* [0 0] [0 1] [0 2] [1 2] [1 3] [3]
*/
LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur,
size_t level) const {
PADDLE_ENFORCE_EQ(pre.NumLevels(), cur.NumLevels());
PADDLE_ENFORCE_EQ(pre.NumLevels(), level + 1,
"Only the lowest LoD level supports pack temporarily.");
// calculate the result tensor's shape first
size_t num_instances = 0;
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
if (num_candidates > 0) {
num_instances += num_candidates * (prefix_size + 1);
} else {
num_instances += prefix_size;
}
}
auto res_dims = pre.dims();
res_dims[0] = num_instances;
LoDTensor result;
result.Resize(res_dims);
result.mutable_data<value_type>(cur.place());
Vector<size_t> last_lod_level;
// copy data
size_t index = 0;
last_lod_level.push_back(index);
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
// slice the prefix Tensor
LoDTensor prefix = pre;
prefix.ShrinkInLevel(level, elem, elem + 1);
LoDTensor candidate = cur;
if (num_candidates > 0) {
candidate.ShrinkInLevel(level, elem, elem + 1);
} else { // just push prefix
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
last_lod_level.push_back(index);
}
for (size_t candi = 0; candi < num_candidates; candi++) {
// TODO(superjom) support GPU
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
// copy candidate record
result.Slice(index, index + 1)
.CopyFrom(candidate.Slice(candi, candi + 1), result.place(),
platform::CPUDeviceContext());
index++;
last_lod_level.push_back(index);
}
}
// update lod
auto lod = cur.lod();
lod.back() = last_lod_level;
result.set_lod(lod);
return result;
}
/*
* source [0 1 2] [3 4] [5 6 7] will be transformd to a list of LoDTensors such
* as
* [0 3 5] [1 4 6] [2 7] with 1-level LoDs:
* - [0 1 2 3]
* - [0 1 2 3]
* - [0 1 1 2], the [1,1) here means the second sequence is empty
*
* NOTE Unpack a LoDTensor in this approach may result in a big LoD.
*/
void TensorArray::LodUnpack(const LoDTensor& source, size_t level) {
PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1,
"only the lowest LoD level supports unpack.");
int non_empty_instances = -1;
size_t index = 0;
Vector<size_t> lowest_lod_level;
lowest_lod_level.push_back(index);
for (size_t step = 0; non_empty_instances > 0 || non_empty_instances == -1;
step++) {
size_t num_instances = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
num_instances++;
index++;
}
lowest_lod_level.push_back(index);
}
// create tensor for this time step
LoDTensor tensor;
auto dims = source.dims();
dims[0] = num_instances;
// set lod
auto lod = source.lod();
lod.back() = lowest_lod_level;
tensor.set_lod(lod);
index = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
// copy this instance
tensor.Slice(index, index + 1)
.CopyFrom(instance.Slice(step, step + 1), tensor.place(),
platform::CPUDeviceContext());
index++;
}
}
Write(step, tensor);
}
}
LoDTensor TensorArray::Stack() const { LoDTensor TensorArray::Stack() const {
LoDTensor result; LoDTensor result;
if (size() == 0) return result; if (size() == 0) return result;
......
...@@ -86,6 +86,16 @@ class TensorArray { ...@@ -86,6 +86,16 @@ class TensorArray {
*/ */
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend); DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
/*
* Pack an array of LoDTensors to a LoDTensor.
*/
LoDTensor LodPack(size_t level) const;
/*
* Unpack a LoDTensor to an array of LoDTensors.
*/
void LodUnpack(const LoDTensor &source, size_t level);
/* /*
* Pack the values into a tensor with rank one higher than each tensor in * Pack the values into a tensor with rank one higher than each tensor in
* values. * values.
...@@ -111,6 +121,9 @@ class TensorArray { ...@@ -111,6 +121,9 @@ class TensorArray {
protected: protected:
void Unstack(const LoDTensor &source, bool data_shared) const; void Unstack(const LoDTensor &source, bool data_shared) const;
LoDTensor LodPackTwo(const LoDTensor &pre, const LoDTensor &cur,
size_t level) const;
private: private:
mutable std::vector<LoDTensor> values_; mutable std::vector<LoDTensor> values_;
}; // class TensorArray }; // class TensorArray
......
...@@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) { ...@@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) {
ASSERT_EQ(ta.size(), static_cast<size_t>(batch_size)); ASSERT_EQ(ta.size(), static_cast<size_t>(batch_size));
} }
TEST(TensorArray, LodPack) {
// three time steps, each step stores a LoDTensors
// - [0] [1]
// - [2 3], [4 5]
// - [6 7] [] [8], [9, 10]
// try to get a LoDTensor with content:
// - [0 2 6]
// - [0 2 7]
// - [0 3]
// - [1 4 8]
// - [1 5 9]
// - [1 5 10]
std::array<LoDTensor, 3> tensors;
tensors[0].Resize(make_ddim({2, 1}));
tensors[1].Resize(make_ddim({4, 1}));
tensors[2].Resize(make_ddim({5, 1}));
int index = 0;
for (auto& t : tensors) {
t.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < t.dims()[0]; i++) {
t.data<int>()[i] = index;
index++;
}
}
std::array<LoD, 3> lods;
std::vector<std::vector<size_t>> levels{
{0, 1, 2}, {0, 2, 4}, {0, 2, 2, 3, 5}};
for (int i = 0; i < 3; i++) {
lods[i].emplace_back(levels[i].begin(), levels[i].end());
}
TensorArray ta;
for (int i = 0; i < 3; i++) {
tensors[i].set_lod(lods[i]);
ta.Write(i, tensors[i]);
}
auto merged = ta.LodPack(0);
std::vector<int> target_tensor_data{{0, 2, 6, // 0
0, 2, 7, // 1
0, 3, // 2
1, 4, 8, // 3
1, 5, 9, // 5
1, 5, 10}};
EXPECT_EQ(merged.dims()[0], (int)target_tensor_data.size());
for (size_t i = 0; i < target_tensor_data.size(); i++) {
EXPECT_EQ(target_tensor_data[i], merged.data<int>()[i]);
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/batch_norm_op.h"
#include <cfloat>
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
void ExtractNCWHD(const framework::DDim &dims,
const TensorFormat &tensor_format, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
*C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
*H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
: 1;
}
template <typename T>
class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
VLOG(1) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * D * C, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0);
// FIXME(qiao) should not set zero self
functor(ctx.device_context(), mean_out, 0);
functor(ctx.device_context(), variance_out, 0);
auto handle = ctx.cuda_device_context().cudnn_handle();
// Now, depending on whether we are running test or not, we have two paths.
if (is_test) {
// only when test we use input to do computation.
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
// Run inference mode.
PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(est_mean->dims()[0], C);
PADDLE_ENFORCE_EQ(est_var->dims()[0], C);
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference(
handle,
// Note: PERSISTENT not implemented for inference
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// initialize them.
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(), bias->template data<T>(), this_factor,
mean_out->template mutable_data<T>(ctx.GetPlace()),
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean->template mutable_data<T>(ctx.GetPlace()),
saved_variance->template mutable_data<T>(ctx.GetPlace())));
}
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
template <typename T>
class BatchNormGradKernel<platform::GPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
std::vector<int> dims = {N, C, H, W, D};
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
ctx.cuda_device_context().cudnn_handle(), mode_,
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
x->template data<T>(), data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(),
d_scale->template mutable_data<T>(ctx.GetPlace()),
d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean_data, saved_var_data));
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(batch_norm,
ops::BatchNormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::GPUPlace, float>);
...@@ -15,7 +15,8 @@ template <class T> ...@@ -15,7 +15,8 @@ template <class T>
class TensorT { class TensorT {
public: public:
TensorT(size_t size) : height_(1), width_(size) { TensorT(size_t size) : height_(1), width_(size) {
data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>()); // new T[size]() initializes all element to zero value.
data_ptr_ = std::shared_ptr<T>(new T[size](), std::default_delete<T[]>());
data_ = data_ptr_.get(); data_ = data_ptr_.get();
} }
......
...@@ -22,6 +22,47 @@ limitations under the License. */ ...@@ -22,6 +22,47 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
inline const char* cudnnGetErrorString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
default:
return "Unknown cudnn error number";
}
}
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#define CUDNN_ENFORCE(condition) \
do { \
cudnnStatus_t status = condition; \
if (status != CUDNN_STATUS_SUCCESS) { \
VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \
PADDLE_THROW("cuDNN call failed"); \
} \
} while (false)
enum class DataLayout { enum class DataLayout {
kNHWC, kNHWC,
kNCHW, kNCHW,
...@@ -40,12 +81,30 @@ template <> ...@@ -40,12 +81,30 @@ template <>
class CudnnDataType<float> { class CudnnDataType<float> {
public: public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT; static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
}; };
template <> template <>
class CudnnDataType<double> { class CudnnDataType<double> {
public: public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
typedef const double ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
}; };
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) { inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
......
...@@ -83,6 +83,7 @@ extern void* cudnn_dso_handle; ...@@ -83,6 +83,7 @@ extern void* cudnn_dso_handle;
__macro(cudnnDestroyConvolutionDescriptor); \ __macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \ __macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \ __macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreate); \ __macro(cudnnCreate); \
__macro(cudnnDestroy); \ __macro(cudnnDestroy); \
__macro(cudnnSetStream); \ __macro(cudnnSetStream); \
......
...@@ -43,11 +43,6 @@ void NewRemoteParameterUpdater::init( ...@@ -43,11 +43,6 @@ void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) { const std::vector<ParameterPtr> &parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
for (auto &para : parameters_) {
para->getBuf(PARAMETER_VALUE)->zeroMem();
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
// create parameter server client. // create parameter server client.
if (useEtcd_) { if (useEtcd_) {
parameterClient_ = parameterClient_ =
...@@ -109,6 +104,8 @@ void NewRemoteParameterUpdater::init( ...@@ -109,6 +104,8 @@ void NewRemoteParameterUpdater::init(
LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: " LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
<< trainerConfig_.learning_rate_schedule() << ", set to const"; << trainerConfig_.learning_rate_schedule() << ", set to const";
optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const); optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
trainerConfig_.learning_rate());
} }
// overwrite optimizerConfigV2 for per-parameter(layer) configs // overwrite optimizerConfigV2 for per-parameter(layer) configs
......
...@@ -65,7 +65,14 @@ def download(url, module_name, md5sum): ...@@ -65,7 +65,14 @@ def download(url, module_name, md5sum):
os.makedirs(dirname) os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1]) filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum): retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError("Cannot download {0} within retry limit {2}".
format(url, retry_limit))
print "Cache file %s not found, downloading %s" % (filename, url) print "Cache file %s not found, downloading %s" % (filename, url)
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
......
...@@ -211,13 +211,14 @@ class MomentumOptimizer(Optimizer): ...@@ -211,13 +211,14 @@ class MomentumOptimizer(Optimizer):
""" """
_velocity_acc_str = "velocity" _velocity_acc_str = "velocity"
def __init__(self, learning_rate, momentum): def __init__(self, learning_rate, momentum, use_nesterov=False):
assert learning_rate is not None assert learning_rate is not None
assert momentum is not None assert momentum is not None
super(MomentumOptimizer, self).__init__() super(MomentumOptimizer, self).__init__()
self.type = "momentum" self.type = "momentum"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._momentum = momentum self._momentum = momentum
self._use_nesterov = bool(use_nesterov)
def _initialize_tensors(self, block): def _initialize_tensors(self, block):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -259,7 +260,8 @@ class MomentumOptimizer(Optimizer): ...@@ -259,7 +260,8 @@ class MomentumOptimizer(Optimizer):
"ParamOut": param_and_grad[0], "ParamOut": param_and_grad[0],
"VelocityOut": velocity_acc "VelocityOut": velocity_acc
}, },
attrs={"mu": self._momentum}) attrs={"mu": self._momentum,
"useNesterov": self._use_nesterov})
return momentum_op return momentum_op
......
...@@ -36,7 +36,7 @@ class TestMomentumOptimizer(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestMomentumOptimizer(unittest.TestCase):
def get_velocity_str(self): def get_velocity_str(self):
return self._velocity_acc_str return self._velocity_acc_str
def test_momentum_optimizer(self): def test_vanilla_momentum_optimizer(self):
program = framework.Program() program = framework.Program()
block = program.global_block() block = program.global_block()
mul_x = block.create_parameter( mul_x = block.create_parameter(
...@@ -60,6 +60,42 @@ class TestMomentumOptimizer(unittest.TestCase): ...@@ -60,6 +60,42 @@ class TestMomentumOptimizer(unittest.TestCase):
self.assertEqual(len(opts), 1) self.assertEqual(len(opts), 1)
sgd_op = opts[0] sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum") self.assertEqual(sgd_op.type, "momentum")
self.assertFalse(sgd_op.attr('useNesterov'))
# Check accumulators
accumulators = momentum_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 1)
self.assertTrue(momentum_optimizer.get_velocity_str() in accumulators)
velocity_acc = accumulators[momentum_optimizer.get_velocity_str()]
self.assertEqual(len(velocity_acc), 1)
self.assertTrue(mul_x.name in velocity_acc)
def test_nesterov_momentum_optimizer(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
momentum_optimizer = self.MockMomentum(
learning_rate=0.01, momentum=0.2, use_nesterov=True)
params_grads = append_backward_ops(mul_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(params_grads,
mul_out)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertTrue(sgd_op.attr('useNesterov'))
# Check accumulators # Check accumulators
accumulators = momentum_optimizer.get_accumulators() accumulators = momentum_optimizer.get_accumulators()
......
...@@ -49,7 +49,7 @@ def save_model(parameters, path): ...@@ -49,7 +49,7 @@ def save_model(parameters, path):
' in environment variable.') ' in environment variable.')
etcd_ip = os.environ.get(etcd_name) etcd_ip = os.environ.get(etcd_name)
client = master.client("http://" + etcd_ip + ":2379", 5, 0) client = paddle.v2.master.client("http://" + etcd_ip + ":2379", 5, 0)
r = client.request_save_model(trainer_id, 5000) r = client.request_save_model(trainer_id, 5000)
if r == 0: if r == 0:
# do not need to save # do not need to save
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册