diff --git a/cmake/external/nccl.cmake b/cmake/external/nccl.cmake
index dfbbed58c9ed7cc57809b3d33a29ce26a35d75a2..57d2c0a352507afd01d1cbf2c7b23c00ff7ad81b 100644
--- a/cmake/external/nccl.cmake
+++ b/cmake/external/nccl.cmake
@@ -1,9 +1,8 @@
-INCLUDE(ExternalProject)
+include(ExternalProject)
-SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
-
-INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
+set(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
+include_directories(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies
@@ -12,39 +11,39 @@ if(WITH_DSO)
set(NCCL_INSTALL_DIR "")
else()
# otherwise, we build nccl and link it.
+ set(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
+ # Note: cuda 8.0 is needed to make nccl
+ # When cuda is not installed on the system directory, need to set CUDA_HOME to your cuda root
set(NCCL_BUILD_COMMAND "make -j 8")
- set(NCCL_INSTALL_COMMAND "make install")
- SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
+ set(NCCL_INSTALL_COMMAND "make install PREFIX=${NCCL_INSTALL_DIR}")
endif()
ExternalProject_Add(
- extern_nccl
- ${EXTERNAL_PROJECT_LOG_ARGS}
- GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
- GIT_TAG "v1.3.4-1"
- PREFIX "${NCCL_SOURCE_DIR}"
- UPDATE_COMMAND ""
- CONFIGURE_COMMAND ""
- BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
- INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
- INSTALL_DIR "${NCCL_INSTALL_DIR}"
- TEST_COMMAND ""
+ extern_nccl
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
+ GIT_TAG "v1.3.4-1"
+ PREFIX "${NCCL_SOURCE_DIR}"
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
+ INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
+ INSTALL_DIR "${NCCL_INSTALL_DIR}"
+ TEST_COMMAND ""
)
-if (WITH_DSO)
- if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
- set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
- file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
+if(WITH_DSO)
+ if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
+ set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c)
+ file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile})
else()
add_library(nccl INTERFACE)
endif()
else()
- ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL)
- SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION
- ${NCCL_INSTALL_DIR}/lib/libnccl.a)
+ add_library(nccl STATIC IMPORTED GLOBAL)
+ set_property(TARGET nccl PROPERTY IMPORTED_LOCATION
+ ${NCCL_INSTALL_DIR}/lib/libnccl_static.a)
endif()
add_dependencies(nccl extern_nccl)
-
-LIST(APPEND external_project_dependencies nccl)
diff --git a/doc/design/model_format.md b/doc/design/model_format.md
index db8c36e5f5dca94b516aad2134c1bdc8ccc6c744..e29129fddf775939c9f7a8b49d850d523e6e5a45 100644
--- a/doc/design/model_format.md
+++ b/doc/design/model_format.md
@@ -2,35 +2,35 @@
## Motivation
-The model is the output of training process. One complete model consists of two parts, namely, the **topology** and the **parameters**. To support industrial deployment, we need to make the model format must be self-completed and do not expose any training source code.
+A model is an output of the training process. One complete model consists of two parts, the **topology** and the **parameters**. In order to support industrial deployment, the model format must be self-complete and must not expose any training source code.
-As a result, In PaddlePaddle, the **topology** represents as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model, we must support large size parameter, and efficient serialization/deserialization.
+As a result, In PaddlePaddle, the **topology** is represented as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model. We must support large size parameters and efficient serialization/deserialization of parameters.
## Implementation
-The topology is saved as a plain text, in detail, a self-contain protobuf file.
+The topology is saved as a plain text in a detailed self-contain protobuf file.
-The parameters are saved as a binary file. As we all know, the protobuf message has the limits of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We do a (benchmark experiment)[https://github.com/PaddlePaddle/Paddle/pull/4610], its result shows protobuf is not fit in this scene.
+The parameters are saved as a binary file. As we all know, the protobuf message has a limit of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We have done a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), which shows that protobuf is not fit for the task.
-As a result, we design a particular format for tensor serialization. By default, arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of (LoDTensorDesc)[https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99]. We save the DescProto as the byte string header, it contains the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). Tensor stores value in a continuous memory buffer, for speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
+As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
-|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**|
+The table below shows a tensor's byte view in detail. Note that all the signed values are written in the little-endian format.
+
+|field name | type | description |
+| --- | --- | --- |
+| version | uint32_t | Version of saved file. Always 0 now. |
+| tensor desc length | uint32_t | TensorDesc(Protobuf message) length in bytes. |
+| tensor desc | void* | TensorDesc protobuf binary message |
+| tensor data | void* | Tensor's data in binary format. The length of `tensor_data` is decided by `TensorDesc.dims()` and `TensorDesc.data_type()` |
+| lod_level | uint64_t | Level of LoD |
+| length of lod[0] | uint64_t | [Optional] length of lod[0] in bytes. |
+| data of lod[0] | uint64_t* | [Optional] lod[0].data() |
+| ... | ... | ... |
-In detail, tensor's byte view as the table shows. Note that all the signed value written in little-endian.
-```text
-[offset] [type] [description]
-0004 4 bytes integer HeaderLength, the length of LoDTensorDesc
-0008 4 bytes integer ContentLength, the length of LodTensor Buffer
-0009 1 bytes char TensorDesc
-00010 1 bytes char TensorDesc
-...
-00100 1 bytes char TensorValue
-00101 1 bytes char TensorValue
-00102 1 bytes char TensorValue ..
-...
-```
## Summary
-We introduce the model format, the `ProgramDesc` describe the **topology**, and a bunch of particular format binary tensors describes the **parameters**.
+- We introduce a model format.
+- The model represented by its forward-pass computation procedure is saved in a **ProgramDesc** protobuf message.
+- A bunch of specified format binary tensors describe the **parameters**.
diff --git a/doc/design/regularization.md b/doc/design/regularization.md
index 703a9fbdd4392aa7f44733cce2da19caa1b51e4a..21280ac898feb4dd5e5a5d9e88d121e856850f0b 100644
--- a/doc/design/regularization.md
+++ b/doc/design/regularization.md
@@ -1,7 +1,7 @@
# Regularization in PaddlePaddle
## Introduction to Regularization
-A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**.
+A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. A frequently faced problem is the problem of **overfitting**, where the model does not make reliable predictions on new unseen data. **Regularization** is the process of introducing additional information in order to prevent overfitting. This is usually done by adding extra penalties to the loss function that restricts the parameter spaces that an optimization algorithm can explore.
### Parameter Norm Penalties
Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows:
@@ -18,52 +18,21 @@ The most commonly used norm penalties are the L2 norm penalty and the L1 norm pe
##### L1 Regularization

-A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
+A much more detailed mathematical background of regularization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
+## Regularization Survey
-## How to do Regularization in PaddlePaddle
-
-On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization:
-
-1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows:
- ```python
- opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2)
- ```
- At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet:
- ```python
- if weight_decay != 0:
- d_p.add_(weight_decay, p.data)
- ```
- This is a very restyrictive way of doing regularization and does not give the users enough flexibility.
-
- **Advantages**:
- - It is easy to implement for us.
- - Faster execution of backward. However, it can be done manually by advanced users too.
-
- **Disadvantages**:
- - Not flexible for other regularizations such as L1/L0 regularization.
- - Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized.
- - Tightly coupled optimizer and regularization implementation.
-
-
-2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer.
-
- **Advantages**:
- - Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization.
- - Makes it easy for the users to customize and extend the framework.
-
- **Disadvantages**:
- - Implementation requires comprehensive design and time.
+A detailed survey of regularization in various deep learning frameworks can be found [here](https://github.com/PaddlePaddle/Paddle/wiki/Regularization-Survey).
## Proposal for Regularization in PaddlePaddle
### Low-Level implementation
-In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations:
+In the new design, we propose to create new operations for regularization. For now, we can add 2 ops that correspond to the most frequently used regularizations:
- L2_regularization_op
- L1_regularization_op
-These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties.
+These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate CPU and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes other than L1 and L2 norm penalties.
The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API.
@@ -94,7 +63,7 @@ Since we want to create the regularization ops in a lazy manner, the regularizat
#### High-level API
-In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
+In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we also need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go
index 739c4c01e02b10f46c36b997f8c4700150da2a26..f57db1c0a0107c4fd74b81aedaf4a58ff2a132ec 100644
--- a/go/cmd/master/master.go
+++ b/go/cmd/master/master.go
@@ -25,9 +25,8 @@ import (
"strings"
"time"
+ log "github.com/inconshreveable/log15"
"github.com/namsral/flag"
- log "github.com/sirupsen/logrus"
- "github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
@@ -41,16 +40,20 @@ func main() {
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info",
- "log level, possible values: debug, info, warning, error, fatal, panic")
+ "log level, possible values: debug, info, warn, error, crit")
flag.Parse()
- level, e := log.ParseLevel(*logLevel)
- candy.Must(e)
+ lvl, err := log.LvlFromString(*logLevel)
+ if err != nil {
+ panic(err)
+ }
- log.SetLevel(level)
+ log.Root().SetHandler(
+ log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
+ )
if *endpoints == "" {
- log.Warningln("-endpoints not set, fault tolerance not be enabled.")
+ log.Warn("-endpoints not set, fault tolerance not be enabled.")
}
var store master.Store
@@ -58,23 +61,25 @@ func main() {
eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP()
if err != nil {
- log.Fatal(err)
+ log.Crit("get external ip error", log.Ctx{"error": err})
+ panic(err)
}
addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil {
- log.Fatal(err)
+ log.Crit("error creating etcd client.", log.Ctx{"error": err})
+ panic(err)
}
} else {
store = &master.InMemStore{}
}
shutdown := func() {
- log.Infoln("shutting down gracefully")
+ log.Info("shutting down gracefully")
err := store.Shutdown()
if err != nil {
- log.Errorln(err)
+ log.Error("shutdown error", log.Ctx{"error": err})
}
}
@@ -86,24 +91,28 @@ func main() {
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil {
- log.Fatal(err)
+ log.Crit("error creating new service.", log.Ctx{"error": err})
+ panic(err)
}
err = rpc.Register(s)
if err != nil {
- log.Fatal(err)
+ log.Crit("error registering to etcd.", log.Ctx{"error": err})
+ panic(err)
}
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
- log.Fatal(err)
+ log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
+ panic(err)
}
go func() {
err = http.Serve(l, nil)
if err != nil {
- log.Fatal(err)
+ log.Crit("error serving HTTP", log.Ctx{"error": err})
+ panic(err)
}
}()
diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go
index bec5775d540729000ab2dd3002600f0a92619d70..1358801c1cf7f2e89f8e463560d25145d881d01d 100644
--- a/go/cmd/pserver/pserver.go
+++ b/go/cmd/pserver/pserver.go
@@ -27,11 +27,11 @@ import (
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
func main() {
- port := flag.Int("port", 0, "port of the pserver")
+ port := flag.Int("port", 8001, "port of the pserver")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
@@ -41,13 +41,17 @@ func main() {
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
- "log level, possible values: debug, info, warning, error, fatal, panic")
+ "log level, possible values: debug, info, warn, error, crit")
flag.Parse()
- level, err := log.ParseLevel(*logLevel)
- candy.Must(err)
+ lvl, err := log.LvlFromString(*logLevel)
+ if err != nil {
+ panic(err)
+ }
- log.SetLevel(level)
+ log.Root().SetHandler(
+ log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
+ )
var idx int
@@ -63,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
- log.Infof("Could not find the pserver checkpoint.")
+ log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
@@ -71,10 +75,10 @@ func main() {
}
shutdown := func() {
- log.Infoln("shutting down gracefully")
+ log.Info("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
- log.Errorln(sErr)
+ log.Error("error shutting down", log.Ctx{"error": sErr})
}
}
@@ -95,7 +99,7 @@ func main() {
candy.Must(err)
go func() {
- log.Infof("start pserver at port %d", *port)
+ log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
diff --git a/go/glide.lock b/go/glide.lock
index aabc03657fff299581c61ed2a220e1c615cd6dfe..ce654d36364f8078a493651d8d8b141532eea26d 100644
--- a/go/glide.lock
+++ b/go/glide.lock
@@ -1,5 +1,5 @@
-hash: 328e7b9b7306b45e7b9879139a9f86698115981f6283032e1312093a6a6ddb04
-updated: 2017-10-16T08:00:23.484693528Z
+hash: 51d9e2e46d7fd9173ff11ecada40f7b7728756be18d5e2f032535f66465e6e15
+updated: 2017-10-24T15:04:09.987751592-07:00
imports:
- name: github.com/alecthomas/gometalinter
version: bae2f1293d092fd8167939d5108d1b025eaef9de
@@ -99,6 +99,8 @@ imports:
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
+- name: github.com/go-stack/stack
+ version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
@@ -120,8 +122,14 @@ imports:
- runtime
- runtime/internal
- utilities
+- name: github.com/inconshreveable/log15
+ version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
+- name: github.com/mattn/go-colorable
+ version: 5411d3eea5978e6cdc258b30de592b60df6aba96
+- name: github.com/mattn/go-isatty
+ version: 57fdcb988a5c543893cc61bce354a6e24ab70022
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
@@ -179,11 +187,12 @@ imports:
- lex/httplex
- trace
- name: golang.org/x/sys
- version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
+ version: e48874b42435b4347fc52bdee0424a52abc974d7
repo: https://github.com/golang/sys.git
vcs: git
subpackages:
- unix
+ - windows
- name: golang.org/x/text
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
@@ -222,4 +231,3 @@ testImports:
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
-
diff --git a/go/glide.yaml b/go/glide.yaml
index 4b22ab2caaae2272e3aab0eeba0758925c67d448..ba253f8bebef0ddab810a8303ab1fbe541defbdf 100644
--- a/go/glide.yaml
+++ b/go/glide.yaml
@@ -26,3 +26,7 @@ import:
version: v1.1.0
- package: github.com/alecthomas/gometalinter
version: v1.2.1
+- package: github.com/inconshreveable/log15
+ version: v2.13
+- package: github.com/go-stack/stack
+ version: v1.6.0
diff --git a/go/master/c/client.go b/go/master/c/client.go
index b5759c30b1d7f7dc33e162e959c7de165e02e1da..9a59337108d1aa33929abb480af686a96514655b 100644
--- a/go/master/c/client.go
+++ b/go/master/c/client.go
@@ -35,13 +35,19 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client
+func init() {
+ log.Root().SetHandler(
+ log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
+ )
+}
+
func add(c *master.Client) C.paddle_master_client {
mu.Lock()
defer mu.Unlock()
@@ -117,7 +123,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
}
err := c.SetDataset(paths)
if err != nil {
- log.Errorln(err)
+ log.Error("error set dataset", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR
}
@@ -167,7 +173,7 @@ func paddle_request_save_model(client C.paddle_master_client, trainerID string,
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
- log.Errorln(err)
+ log.Error("error request save model", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR
}
diff --git a/go/master/client.go b/go/master/client.go
index f04cf50ce3cf765a79cbe555d3edb68f3dbb911e..5d657548c9039dfdacf61dd1145deb9777596d9f 100644
--- a/go/master/client.go
+++ b/go/master/client.go
@@ -21,7 +21,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
// Client is the client of the master server.
@@ -75,7 +75,7 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
for {
err := f()
if err != nil {
- log.Warningln(err)
+ log.Warn("create etcd client error", log.Ctx{"error": err})
} else {
break
}
@@ -135,13 +135,13 @@ func (c *Client) getRecords(passID int) {
time.Sleep(time.Second * 3)
continue
}
- log.Errorf("getTask error: %s", err)
+ log.Error("getTask error.", log.Ctx{"error": err})
}
for _, chunk := range t.Chunks {
f, e := os.Open(chunk.Path)
if e != nil {
- log.Errorln(e)
+ log.Error("error open chunk", log.Ctx{"error": e})
continue
}
@@ -152,12 +152,15 @@ func (c *Client) getRecords(passID int) {
if s.Err() != nil {
c.ch <- record{nil, s.Err()}
- log.Errorln(err, chunk.Path)
+ log.Error(
+ "error scan chunk",
+ log.Ctx{"error": err, "path": chunk.Path},
+ )
}
err = f.Close()
if err != nil {
- log.Errorln(err)
+ log.Error("error close record file", log.Ctx{"error": err})
}
}
@@ -166,7 +169,7 @@ func (c *Client) getRecords(passID int) {
// correct, but a reasonable approximation.
err = c.taskFinished(t.Meta.ID)
if err != nil {
- log.Errorln(err)
+ log.Error("task finish callback error.", log.Ctx{"error": err})
}
}
}
@@ -179,12 +182,12 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
if curMaster == "" {
err := c.conn.Close()
if err != nil {
- log.Errorln(err)
+ log.Error("close old master addr error", log.Ctx{"error": err})
}
} else {
err := c.conn.Connect(curMaster)
if err != nil {
- log.Errorln(err)
+ log.Error("connect to new master addr error", log.Ctx{"error": err})
// connect to addr failed, set
// to last known addr in order
diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go
index d5f3d79464655540a29eaa6395057aa5795c4615..2f13fd0dcda85ee10669133ed011f47ce418b61c 100644
--- a/go/master/client_internal_test.go
+++ b/go/master/client_internal_test.go
@@ -25,8 +25,6 @@ import (
"testing"
"time"
- log "github.com/sirupsen/logrus"
-
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
@@ -36,10 +34,6 @@ const (
chunkPerTask = 10
)
-func init() {
- log.SetLevel(log.ErrorLevel)
-}
-
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go
index 94848d887e8bc4b055a7c8b89b9b7f26a39229d1..2a41d36949cb19d9076c0ed00c8db6e235f1296c 100644
--- a/go/master/etcd_client.go
+++ b/go/master/etcd_client.go
@@ -20,7 +20,7 @@ import (
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
const (
@@ -44,7 +44,7 @@ type EtcdClient struct {
// NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
- log.Debugf("Connecting to etcd at %v", endpoints)
+ log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: dialTimeout,
@@ -64,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
- log.Infof("Trying to acquire lock at %s.", lockPath)
+ log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
err = lock.Lock(context.TODO())
if err != nil {
return nil, err
}
- log.Infof("Successfully acquired lock at %s.", lockPath)
+ log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
@@ -78,7 +78,8 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
}
if !resp.Succeeded {
- log.Fatal("No longer owns the master lock. Exiting.")
+ log.Crit("No longer owns the master lock. Exiting.")
+ panic("No longer owns the master lock. Exiting.")
}
e := &EtcdClient{
@@ -102,7 +103,7 @@ func (e *EtcdClient) Save(state []byte) error {
}
if !resp.Succeeded {
- log.Errorln("No longer owns the lock, trying to lock again")
+ log.Error("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx)
cancel()
@@ -116,9 +117,10 @@ func (e *EtcdClient) Save(state []byte) error {
// to kill current master server. The current
// state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry.
- log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err)
+ log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
+ panic("Could not acquire the lock at %s: %v. Exiting.")
}
- log.Infof("Successfully acquired lock at %s.", e.lockPath)
+ log.Info("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state)
}
@@ -136,7 +138,7 @@ func (e *EtcdClient) Load() ([]byte, error) {
}
if !resp.Succeeded {
- log.Errorln("No longer owns the lock, trying to lock and load again.")
+ log.Error("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background())
if err != nil {
return nil, err
@@ -163,7 +165,7 @@ func (e *EtcdClient) Shutdown() error {
if err == nil {
err = newErr
} else {
- log.Errorln(newErr)
+ log.Error("shutdown error", log.Ctx{"error": newErr})
}
}
@@ -192,7 +194,7 @@ func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
- log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
+ log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
valChan <- string(ev.Kv.Value)
}
}
diff --git a/go/master/service.go b/go/master/service.go
index df7c6860e6ae13a5be7d0425273812208685ee9d..f3501028800c850a521d4b08db323cb70fe926d2 100644
--- a/go/master/service.go
+++ b/go/master/service.go
@@ -25,7 +25,7 @@ import (
"sync"
"time"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
"github.com/PaddlePaddle/recordio"
)
@@ -170,11 +170,11 @@ func (s *Service) recover() (bool, error) {
}
if state == nil {
- log.Infoln("No state exists, not recovered.")
+ log.Info("No state exists, not recovered.")
return false, nil
}
- log.Infof("Loaded snapshot of size: %d bytes.", len(state))
+ log.Info("Loaded snapshot.", log.Ctx{"size": len(state)})
gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil {
return false, err
@@ -191,11 +191,11 @@ func (s *Service) recover() (bool, error) {
if err != nil {
// Only close failed, recover actually succeed, so
// just log error.
- log.Errorln(err)
+ log.Error("error close recover file.", log.Ctx{"error": err})
}
s.state = tqs
- log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.")
+ log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx())
for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
}
@@ -224,7 +224,7 @@ func (s *Service) snapshot() error {
}
state := buf.Bytes()
- log.Infof("Saving snapshot of size: %d bytes.", len(state))
+ log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)})
return s.store.Save(state)
}
@@ -260,7 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
}
count := index.NumChunks()
- log.Infof("readChunks: file %s has %d chunks", path, count)
+ log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count})
for i := 0; i < count; i++ {
chunk := Chunk{
Path: path,
@@ -300,7 +300,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
err = s.snapshot()
if err != nil {
- log.Errorln(err)
+ log.Error("snapshot error", log.Ctx{"error": err})
return err
}
close(s.ready)
@@ -320,7 +320,7 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
defer func() {
err := s.snapshot()
if err != nil {
- log.Errorln(err)
+ log.Error("snapshot error", log.Ctx{"error": err})
}
}()
@@ -328,12 +328,12 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
t.NumFailure++
if t.NumFailure > s.failureMax {
- log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
+ log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Failed = append(s.state.Failed, t)
return
}
- log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
+ log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Todo = append(s.state.Todo, t)
return
}
@@ -353,8 +353,8 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
}
// must be called with lock held.
-func (s *Service) logFields() log.Fields {
- return log.Fields{
+func (s *Service) logCtx() log.Ctx {
+ return log.Ctx{
"todoLen": len(s.state.Todo),
"pendingLen": len(s.state.Pending),
"doneLen": len(s.state.Done),
@@ -383,10 +383,10 @@ func (s *Service) GetTask(passID int, task *Task) error {
if len(s.state.Todo) == 0 {
if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
- log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
+ log.Warn("All tasks failed, may start next pass", s.logCtx())
return ErrAllTaskFailed
}
- log.WithFields(s.logFields()).Warningln("No more available task.")
+ log.Warn("No more available task.", s.logCtx())
return ErrNoMoreAvailable
}
@@ -400,8 +400,9 @@ func (s *Service) GetTask(passID int, task *Task) error {
}
*task = t.Task
- log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
-
+ ctx := s.logCtx()
+ ctx["task meta"] = t.Task.Meta
+ log.Info("Task dispatched.", ctx)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil
}
@@ -417,7 +418,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.state.Pending[taskID]
if !ok {
- log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
+ ctx := s.logCtx()
+ ctx["task id"] = taskID
+ log.Warn("Pending task not found.", ctx)
return nil
}
@@ -426,7 +429,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = append(s.state.Done, t)
delete(s.state.Pending, taskID)
- log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
+ ctx := s.logCtx()
+ ctx["task id"] = taskID
+ log.Info("Task finished.", ctx)
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished
s.state.CurPass++
@@ -434,12 +439,14 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{}
- log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass)
+ ctx := s.logCtx()
+ ctx["new pass"] = s.state.CurPass
+ log.Warn("all task finished, add new pass data.", ctx)
}
err := s.snapshot()
if err != nil {
- log.Errorln(err)
+ log.Error("snapshot error", log.Ctx{"error": err})
}
return err
}
@@ -455,7 +462,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
t, ok := s.state.Pending[meta.ID]
if !ok {
- log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
+ log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta})
return nil
}
diff --git a/go/pserver/client/c/cclient.go b/go/pserver/client/c/cclient.go
index a49cd01522b8b49a74f21fcb97e9eeb1fbb2d272..2eeec1b6b3c28556e02780e40ae5d6b693dce484 100644
--- a/go/pserver/client/c/cclient.go
+++ b/go/pserver/client/c/cclient.go
@@ -45,9 +45,15 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
+func init() {
+ log.Root().SetHandler(
+ log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
+ )
+}
+
var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client
@@ -164,10 +170,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
- log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name)
+ log.Warn(
+ "parameter already initialized, treat paddle_init_param as successful.",
+ log.Ctx{"parameter": name},
+ )
return C.PSERVER_OK
}
- log.Errorln(err)
+ log.Error("error init param", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
@@ -180,11 +189,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams()
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
- log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.")
+ log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK
}
- log.Errorln(err)
+ log.Error("error finish init params", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
@@ -205,7 +214,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient
c := get(client)
err := c.SendGrads(gs)
if err != nil {
- log.Errorln(err)
+ log.Error("error send grads", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
@@ -222,7 +231,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
c := get(client)
ps, err := c.GetParams(ns)
if err != nil {
- log.Errorln(err)
+ log.Error("error get params", log.Ctx{"error": err})
return C.PSERVER_ERROR
}
@@ -231,7 +240,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps {
pn[i] = p.Name
}
- log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
+ log.Error(
+ "pserver returned wrong number of parameters.",
+ log.Ctx{
+ "Requested": strings.Join(pn, ", "),
+ "Returned": strings.Join(ns, ", "),
+ },
+ )
return C.PSERVER_ERROR
}
@@ -241,7 +256,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps {
pn[i] = p.Name
}
- log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
+ log.Error(
+ "pserver returned wrong parameters, or not in requested order.",
+ log.Ctx{
+ "Requested": strings.Join(pn, ", "),
+ "Returned": strings.Join(ns, ", "),
+ },
+ )
return C.PSERVER_ERROR
}
}
@@ -251,13 +272,19 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nil {
- log.Errorln("must pre-allocate parameter.")
+ log.Error("must pre-allocate parameter.")
return C.PSERVER_ERROR
}
if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) {
- log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
+ log.Error(
+ "the pre-allocated content len does not match parameter content len.",
+ log.Ctx{
+ "Pre-allocated len": param.content_len,
+ "Returned len": len(p.Content),
+ },
+ )
return C.PSERVER_ERROR
}
}
diff --git a/go/pserver/client/client.go b/go/pserver/client/client.go
index e5187ce3df77cb983e070508230c51c078f1e07b..18fce34b376a8f60900700c588e30f92ef3514ed 100644
--- a/go/pserver/client/client.go
+++ b/go/pserver/client/client.go
@@ -22,7 +22,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
// TODO(helin): add RPC call retry logic
@@ -84,7 +84,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
if curServers[i].Addr == "" {
err := c.pservers[i].Close()
if err != nil {
- log.Errorln(err)
+ log.Error("error closing connection to pserver", log.Ctx{"error": err})
}
continue
@@ -92,7 +92,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil {
- log.Errorln(err)
+ log.Error("error connecting to pserver", log.Ctx{"error": err})
// connect to addr failed, set
// to last known addr in order
diff --git a/go/pserver/client/client_test.go b/go/pserver/client/client_test.go
index c3d88e926d7cb5f3027be26a270bee6f2db65f31..ec832305ee8e24967b06b6b621c44cde30c09e55 100644
--- a/go/pserver/client/client_test.go
+++ b/go/pserver/client/client_test.go
@@ -30,7 +30,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
const (
@@ -90,7 +90,7 @@ func initEtcdClient() {
DialTimeout: time.Second * time.Duration(1),
})
if err != nil {
- log.Errorf("err %v", err)
+ log.Error("error init etcd client", log.Ctx{"error": err})
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err = client.Delete(ctx, pserver.PsDesired)
diff --git a/go/pserver/client/etcd_client.go b/go/pserver/client/etcd_client.go
index f9071caaa8f5ac32d426b1d4344a30262202b96d..16d0c3b943050f05c54a3e010054fd7c2f33b6d6 100644
--- a/go/pserver/client/etcd_client.go
+++ b/go/pserver/client/etcd_client.go
@@ -25,7 +25,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
const (
@@ -54,26 +54,29 @@ func (e *Etcd) Desired() int {
resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
- log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
+ log.Error(
+ "Get ps dresire number failed! reconnecting...",
+ log.Ctx{"error": err},
+ )
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
- log.Infoln("Waiting for ps desired registered ...")
+ log.Info("Waiting for ps desired registered ...")
time.Sleep(e.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
- log.Errorf("psDesired %d invalid %v", psDesired, err)
+ log.Error("atoi failed", log.Ctx{"error": err})
time.Sleep(e.timeout)
continue
}
- log.Debugf("Get psDesired number: %d", psDesired)
+ log.Debug("Got psDesired", log.Ctx{"psDesired": psDesired})
break
}
return psDesired
@@ -88,17 +91,20 @@ func (e *Etcd) List() []Server {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i)
- log.Debugf("checking %s", psKey)
+ log.Debug("looking for pserver", log.Ctx{"ps key": psKey})
resp, err := e.client.Get(ctx, psKey)
cancel()
if err != nil {
- log.Infof("Get psKey= %s error, %v", psKey, err)
+ log.Info(
+ "Get psKey error",
+ log.Ctx{"ps key": psKey, "error": err},
+ )
time.Sleep(e.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
- log.Infof("Waiting for ps addr registered ...")
+ log.Info("Waiting for ps addr registered ...")
time.Sleep(e.timeout)
continue
}
@@ -106,11 +112,17 @@ func (e *Etcd) List() []Server {
psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address
if psAddr == "" {
- log.Infof("Get psKey = %s, psAddr is empty", psKey)
+ log.Info(
+ "Value under psKey is empty",
+ log.Ctx{"psKey": psKey},
+ )
time.Sleep(e.timeout)
continue
}
- log.Debugf("got value (%s) for key: %s", psAddr, psKey)
+ log.Debug(
+ "got psAddr given psKey",
+ log.Ctx{"psAddr": psAddr, "psKey": psKey},
+ )
servers[i].Index = i
servers[i].Addr = psAddr
}
@@ -130,13 +142,13 @@ func NewEtcd(endpoints string) *Etcd {
DialTimeout: defaultEtcdTimeout,
})
if err != nil {
- log.Errorf("Init etcd connection failed: %v", err)
+ log.Error("Init etcd connection failed", log.Ctx{"error": err})
time.Sleep(defaultEtcdTimeout)
continue
}
break
}
- log.Infof("Connected to etcd: %s\n", endpoints)
+ log.Info("Connected to etcd endpoint", log.Ctx{"endpoint": endpoints})
client := &Etcd{
client: cli,
timeout: defaultEtcdTimeout,
@@ -154,7 +166,7 @@ func (e *Etcd) Select() (bool, error) {
}
lock := concurrency.NewMutex(sess, initLockPath)
- log.Infof("Trying to acquire lock at %s.", initLockPath)
+ log.Info("Trying to acquire lock", log.Ctx{"lock path": initLockPath})
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
@@ -162,7 +174,7 @@ func (e *Etcd) Select() (bool, error) {
if err != nil {
return false, err
}
- log.Infof("Successfully acquired lock at %s.", initLockPath)
+ log.Info("Successfully acquired lock", log.Ctx{"lock path": initLockPath})
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
@@ -181,17 +193,17 @@ func (e *Etcd) Select() (bool, error) {
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
- log.Infoln("Trainer selected.")
+ log.Info("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
- log.Infoln("Initialization is already done.")
+ log.Info("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
- log.Errorln(err)
+ log.Error("error unlocking", log.Ctx{"error": err})
}
return false, nil
}
@@ -221,7 +233,7 @@ func (e *Etcd) Done() error {
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
- log.Errorln(err)
+ log.Error("error unlocking", log.Ctx{"error": err})
} else {
e.lock = nil
}
@@ -244,7 +256,7 @@ func (e *Etcd) Close() error {
cErr := e.client.Close()
if cErr != nil {
if err != nil {
- log.Errorln(cErr)
+ log.Error("error closing etcd client", log.Ctx{"error": cErr})
return err
}
return cErr
diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go
index 41f0640fc09a3265c0e11c06255c7ee834983203..08ddb247f26379da80d485b1a6059f793864b786 100644
--- a/go/pserver/etcd_client.go
+++ b/go/pserver/etcd_client.go
@@ -24,7 +24,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
const (
@@ -82,19 +82,19 @@ func (e *EtcdClient) Register(port int) (int, error) {
DialTimeout: e.dialTimeout,
})
if err != nil {
- log.Errorf("connect to etcd error: %v", err)
+ log.Error("connect to etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout)
continue
}
e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
- log.Errorf("create etcd session error: %v", err)
+ log.Error("create etcd session error", log.Ctx{"error": err})
time.Sleep(retryTimeout)
continue
}
e.sess = sess
- log.Debugf("inited client to %s", e.endpoints)
+ log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints})
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
@@ -104,7 +104,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
_, err := e.initDesiredPservers(ctx, e.numPservers)
cancel()
if err != nil {
- log.Warn(err)
+ log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers})
time.Sleep(retryTimeout)
continue
}
@@ -119,14 +119,17 @@ func (e *EtcdClient) Register(port int) (int, error) {
resp, err := e.client.Get(ctx, PsDesired)
cancel()
if err != nil {
- log.Errorf("getting %s error: %v", PsDesired, err)
+ log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err})
time.Sleep(retryTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
- log.Errorf("value of %s invalid %v\n", PsDesired, err)
+ log.Error(
+ "psDesired atoi error",
+ log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)},
+ )
time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change
continue
@@ -143,7 +146,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel()
if err != nil {
- log.Warn(err)
+ log.Warn("register pserver on etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout)
continue
}
@@ -170,16 +173,17 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
registered := false
for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i)
- log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
- log.Debugf("got value (%s) for key: %s", ps, psKey)
+ log.Debug(
+ "register pserver got value",
+ log.Ctx{"value": ps, "key": psKey},
+ )
if ps == "" {
// find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
- log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
- log.Debug("register finished")
+ log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr})
idx = i
registered = true
break
@@ -239,7 +243,7 @@ func (e *EtcdClient) Shutdown() error {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
- log.Errorln(newErr)
+ log.Error("shutdown error", log.Ctx{"error": newErr})
} else {
err = newErr
}
diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go
index 51ffba5c74d82b7f24d5ab6218e47479c4d18658..6d28cad25a79d713dc06b72f96087a6b723453cd 100644
--- a/go/pserver/optimizer.go
+++ b/go/pserver/optimizer.go
@@ -25,7 +25,7 @@ import (
"fmt"
"unsafe"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
type optimizer struct {
@@ -56,12 +56,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
c := paramWithConfigs.Config
s := State
paramBufferSize := C.size_t(len(p.Content))
- log.WithFields(log.Fields{
+ log.Info("New Optimizer Created with config", log.Ctx{
"ElementType": p.ElementType,
"ParamSize": paramBufferSize,
"ConfigSize": len(c),
"StateSize": len(s),
- }).Info("New Optimizer Created with config:")
+ })
var cbuffer unsafe.Pointer
cbuffer = C.malloc(paramBufferSize)
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}
+ var cptr (*C.uchar)
+ if len(c) > 0 {
+ cptr = (*C.uchar)(&c[0])
+ } else {
+ log.Error("empty config", "param name", paramWithConfigs.Param.Name)
+ }
o.config = c
o.opt = C.paddle_create_optimizer(
- (*C.uchar)(&c[0]),
+ cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
diff --git a/go/pserver/service.go b/go/pserver/service.go
index 29e953acdd8ae16d13cf2307e212f8a18f0f2190..f703d99a29ae9f5310ef36a7492b729c4c892937 100644
--- a/go/pserver/service.go
+++ b/go/pserver/service.go
@@ -17,12 +17,11 @@ package pserver
import (
"bufio"
"bytes"
- "crypto/md5"
"encoding/gob"
- "encoding/hex"
"encoding/json"
"errors"
"fmt"
+ "hash/crc32"
"io/ioutil"
"os"
"path"
@@ -32,7 +31,7 @@ import (
uuid "github.com/satori/go.uuid"
- log "github.com/sirupsen/logrus"
+ log "github.com/inconshreveable/log15"
)
// ElementType is the type of elements of a Parameter.
@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
-var ErrCheckpointNotFound = errors.New("checkpoint not found")
+var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
// RPC error message.
const (
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
- MD5 string `json:"md5"`
+ CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}
@@ -92,7 +91,7 @@ type Service struct {
idx int
checkpointInterval time.Duration
checkpointPath string
- client *EtcdClient
+ client KVStore
mu sync.Mutex
optMap map[string]*optimizer
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte
}
-func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
+type KVStore interface {
+ GetKey(key string, timeout time.Duration) ([]byte, error)
+ PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
+}
+
+func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
@@ -123,7 +127,10 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
// LoadCheckpoint loads checkpoint from file.
-func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
+func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
+ log.Info("Loading checkpoint", "pserver index", idx)
+ defer traceTime(time.Now(), "load checkpoint")
+
cpMeta, err := loadMeta(e, idx)
if err != nil {
return nil, err
@@ -134,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err
}
- // TODO(helin): change MD5 to CRC since CRC is better for file
- // checksum in our use case (emphasize speed over security).
- h := md5.New()
- md5 := hex.EncodeToString(h.Sum(content))
- if md5 != cpMeta.MD5 {
+ crc32 := crc32.ChecksumIEEE(content)
+ if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}
@@ -147,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil {
return nil, err
}
+
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
-func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
+func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
@@ -170,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
+ close(s.initialized)
}
return s, nil
}
@@ -178,6 +184,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select {
case <-s.initialized:
+ log.Warn("init param called but parameters already initialized.")
return errors.New(AlreadyInitialized)
default:
}
@@ -191,6 +198,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
+ log.Info(
+ "init parameter",
+ "name", paramWithConfigs.Param.Name,
+ "config len", len(paramWithConfigs.Config),
+ "param len", len(paramWithConfigs.Param.Content),
+ "type", paramWithConfigs.Param.ElementType,
+ )
return nil
}
@@ -199,6 +213,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
func (s *Service) FinishInitParams(_ int, _ *int) error {
select {
case <-s.initialized:
+ log.Warn("finished init param called but parameters already initialized.")
return errors.New(AlreadyInitialized)
default:
}
@@ -209,10 +224,12 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t {
err := s.checkpoint()
if err != nil {
- log.Errorln(err)
+ log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
+
+ log.Info("init parameter finished.")
return nil
}
@@ -222,6 +239,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
select {
case <-s.initialized:
default:
+ log.Warn("received gradient before initialization.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
return errors.New(Uninitialized)
}
@@ -233,6 +251,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
return fmt.Errorf("parameter: %s does not exist", g.Name)
}
+ log.Info("received gradient from trainer, updating gradient.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
return o.UpdateParameter(g)
}
@@ -244,6 +263,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
opt, ok := s.optMap[name]
if !ok {
+ log.Warn("trainer wants to get a parameter that does not exist.", "name", name)
return fmt.Errorf("parameter: %s does not exist", name)
}
@@ -257,12 +277,14 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
+
+ log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
func traceTime(start time.Time, name string) {
elapsed := time.Since(start)
- log.Infof("%s took %v", name, elapsed)
+ log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
}
// checkpoint saves checkpoint to disk.
@@ -270,7 +292,7 @@ func traceTime(start time.Time, name string) {
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
- log.Infoln("Begin save checkpoint.")
+ log.Info("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock()
@@ -315,7 +337,7 @@ func (s *Service) checkpoint() (err error) {
closeErr := f.Close()
if closeErr != nil {
if err != nil {
- log.Errorln(closeErr)
+ log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
} else {
// Set closeErr as return value.
err = closeErr
@@ -336,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
- log.Infoln("Do not have existing checkpoint.")
+ log.Info("old meta not found, skip removing old meta")
err = nil
+ } else if err == nil {
+ log.Info("removing old meta")
+ if oldMeta.Path != "" {
+ rmErr := os.Remove(oldMeta.Path)
+ if rmErr != nil {
+ // log error, but still treat checkpoint as
+ // successful.
+ log.Error("remove old meta file error", log.Ctx{"error": rmErr})
+ }
+ }
}
if err != nil {
return
}
- h := md5.New()
- md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
+ crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
- MD5: md5,
+ CRC32: crc32,
Path: p,
}
@@ -363,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}
- if oldMeta.Path != "" {
- rmErr := os.Remove(oldMeta.Path)
- if rmErr != nil {
- // log error, but still treat checkpoint as
- // successful.
- log.Errorln(rmErr)
- }
- }
-
return
}
diff --git a/go/pserver/service_internal_test.go b/go/pserver/service_internal_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..36eca5112b3117cf295288de0de957c4af040f03
--- /dev/null
+++ b/go/pserver/service_internal_test.go
@@ -0,0 +1,86 @@
+package pserver
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+const testDir = "./test_data"
+
+type myKV struct {
+ m map[string][]byte
+}
+
+func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
+ if m.m == nil {
+ m.m = make(map[string][]byte)
+ }
+ return m.m[key], nil
+}
+
+func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
+ if m.m == nil {
+ m.m = make(map[string][]byte)
+ }
+ m.m[key] = value
+ return nil
+}
+
+func TestCheckpoint(t *testing.T) {
+ kv := &myKV{}
+ s, err := NewService(0, time.Hour, testDir, kv, nil)
+ assert.Nil(t, err)
+ err = s.checkpoint()
+ assert.Nil(t, err)
+ _, err = LoadCheckpoint(kv, 0)
+ assert.Nil(t, err)
+}
+
+func float32ToByte(f float32) []byte {
+ var buf bytes.Buffer
+ err := binary.Write(&buf, binary.LittleEndian, f)
+ if err != nil {
+ fmt.Println("binary.Write failed:", err)
+ }
+ return buf.Bytes()
+}
+
+func TestCheckpointWithData(t *testing.T) {
+ kv := &myKV{}
+ s, err := NewService(0, time.Hour, testDir, kv, nil)
+ assert.Nil(t, err)
+
+ var content []byte
+ for i := 0; i < 50000; i++ {
+ content = append(content, float32ToByte(float32(i))...)
+ }
+
+ p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
+ err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
+ assert.Nil(t, err)
+
+ err = s.FinishInitParams(0, nil)
+ assert.Nil(t, err)
+
+ var p2 Parameter
+ err = s.GetParam(p1.Name, &p2)
+ assert.Nil(t, err)
+ assert.Equal(t, p1, p2)
+
+ err = s.checkpoint()
+ assert.Nil(t, err)
+ cp, err := LoadCheckpoint(kv, 0)
+ assert.Nil(t, err)
+ s1, err := NewService(0, time.Hour, testDir, kv, cp)
+ assert.Nil(t, err)
+
+ var p3 Parameter
+ err = s1.GetParam(p1.Name, &p3)
+ assert.Nil(t, err)
+ assert.Equal(t, p1, p3)
+}
diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go
index be648cd1e83e4f7790edac5842db432fb4870072..b6f4566eb78cf797e3738afa5f86f5c4e8090d85 100644
--- a/go/pserver/service_test.go
+++ b/go/pserver/service_test.go
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait()
}
-
-func TestCheckpointSpeed(t *testing.T) {
- //TODO(zhihong): test speed
-}
diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp
index 629449bbd497a7444144c533ad079b3ae6b51438..482b51e8a8430863c3e13df2298f6979d3959461 100644
--- a/paddle/capi/gradient_machine.cpp
+++ b/paddle/capi/gradient_machine.cpp
@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config;
+ paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
- return kPD_PROTOBUF_ERROR;
+ if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
+ !modelConfig.IsInitialized()) {
+ return kPD_PROTOBUF_ERROR;
+ }
+ } else {
+ modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create(
- config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
+ modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector& parameters = ptr->machine->getParameters();
for (auto& para : parameters) {
para->load(is);
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 85374a476d51dc4c0e22793e8b53d6d7ba21c8da..0d1617424ecffdcdaaccba6cbd761b2563f6b073 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -1,6 +1,5 @@
# ddim lib
proto_library(framework_proto SRCS framework.proto)
-proto_library(saver_proto SRCS framework.proto saver.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@@ -10,7 +9,7 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
-cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto)
+cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
@@ -27,7 +26,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator)
+cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
@@ -43,7 +42,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op)
-cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
+cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 1ae7fb60f01e4925ceb310f661171eb231eb6c96..150c152367e1bcdc095bce6f77fafdef601e1c47 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -315,6 +315,7 @@ static void CreateGradVarInBlock(
return false; /* not break */
});
if (need_infer_shape) {
+ ops[op_index]->InferVarType(block_desc);
ops[op_index]->InferShape(*block_desc);
}
}
@@ -452,11 +453,16 @@ ParamGradInfoMap AppendBackward(
std::transform(target_shape_desc.begin(), target_shape_desc.end(),
std::back_inserter(target_shape),
[](int64_t dim) { return static_cast(dim); });
+ VLOG(3) << "backward from loss=" << target.Name()
+ << " data_type=" << target.GetDataType();
std::unique_ptr fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", target_shape},
{"value", static_cast(1.0)},
- {"data_type", framework::DataType::FP32}}));
+ {"data_type", target.GetDataType()}}));
+ // infer var type of fill_one_op
+ fill_one_op->InferVarType(root_block);
+
root_block->AppendAllocatedOp(std::move(fill_one_op));
size_t forward_op_num = root_block->OpSize();
size_t forward_block_num = program_desc.Size();
@@ -475,8 +481,7 @@ ParamGradInfoMap AppendBackward(
std::unordered_map retv;
auto var = root_block->Var(fill_one_op_out);
- // FIXME(qiao) infer the data type
- var->SetDataType(framework::DataType::FP32);
+ var->SetDataType(target.GetDataType());
var->SetShape(target.Shape());
auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out;
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 10301f7e39423c8ff0eba33277edecab14c119bf..421f1321948235aa0c1acd2e24037b34716e449a 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -21,6 +21,8 @@
#include "paddle/framework/var_desc.h"
#include "paddle/operators/net_op.h"
+USE_OP(fill_constant);
+
namespace paddle {
namespace framework {
diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc
index 251e340e6ddcc17ba16bdcab63f2a8c907122eab..b73a20cc89d936c2beee6a39cdf71cda3915bcdc 100644
--- a/paddle/framework/block_desc.cc
+++ b/paddle/framework/block_desc.cc
@@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
+
+BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
+ : prog_(prog), desc_(desc), need_update_(false) {
+ for (const VarDesc &var_desc : desc_->vars()) {
+ vars_[var_desc.name()].reset(new VarDescBind(var_desc));
+ }
+ for (const OpDesc &op_desc : desc_->ops()) {
+ ops_.emplace_back(new OpDescBind(op_desc, prog));
+ }
+}
+
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index c685050850dc25f346df49b5ce1d897974870460..72f77a88a24434fd7d2ed685ac850c88888d6808 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -36,8 +36,7 @@ class ProgramDescBind;
class BlockDescBind {
public:
- BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
- : prog_(prog), desc_(desc), need_update_(false) {}
+ BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog);
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index c25a62c2b11ead614d93a4be8d63d40d0cc0165a..bafb4fbd480bf2a28e3aa3dc615a310f80cec493 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -15,6 +15,7 @@
#pragma once
#include
#include "paddle/framework/framework.pb.h"
+#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index 357ad21f39f3b1f6dbdb98063f8fb24ec6800ec6..b731840ef2a4b2d5d82b019d28ad6517fa4b7607 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -28,7 +28,8 @@ enum OpInfoFillType {
kOperator = 0,
kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2,
- kVarTypeInference = 3
+ kVarTypeInference = 3,
+ kShapeInference = 4
};
template
@@ -42,7 +43,10 @@ struct OpInfoFillTypeID {
? kGradOpDescMaker
: (std::is_base_of::value
? kVarTypeInference
- : static_cast(-1))));
+ : (std::is_base_of::value
+ ? kShapeInference
+ : static_cast(
+ -1)))));
}
};
@@ -121,6 +125,16 @@ struct OpInfoFiller {
}
};
+template
+struct OpInfoFiller {
+ void operator()(const char* op_type, OpInfo* info) const {
+ info->infer_shape_ = [](InferShapeContext* ctx) {
+ T inference;
+ inference(ctx);
+ };
+ }
+};
+
} // namespace details
} // namespace framework
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 1f1e4edda823d62b169422672c855d96a2bd2ede..3e9d8b3084e8a76f3d5b8367b0ec45ed74dec42f 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -20,6 +20,7 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
@@ -56,6 +57,22 @@ Executor::~Executor() {
}
}
+static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
+ if (var_type == VarDesc::LOD_TENSOR) {
+ var->GetMutable();
+ } else if (var_type == VarDesc::SELECTED_ROWS) {
+ var->GetMutable();
+ } else if (var_type == VarDesc::FEED_MINIBATCH) {
+ var->GetMutable();
+ } else if (var_type == VarDesc::FETCH_LIST) {
+ var->GetMutable();
+ } else {
+ PADDLE_THROW(
+ "Variable type must be "
+ "LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
+ }
+}
+
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
@@ -69,10 +86,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
for (auto& var : block.vars()) {
if (var.persistable()) {
auto* ptr = scope->Var(var.name());
+ CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope.Var(var.name());
+ CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " locally, which pointer is " << ptr;
}
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index f53dd1c1858b45d39692eb683bc1dd9ee75b88fb..584308a5388da0d02d29f71a28097b02b6ea825f 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -13,7 +13,6 @@
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
-#include "paddle/framework/saver.pb.h"
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
@@ -106,6 +105,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const {
return lod_[level][idx + 1] - lod_[level][idx];
}
+size_t LoDTensor::NumInstancesInElement(size_t level, size_t idx) const {
+ PADDLE_ENFORCE_LT(level, NumLevels());
+ PADDLE_ENFORCE_LT(idx, NumElements(level));
+ auto abs_lod = ToAbsOffset(lod());
+ size_t begin = abs_lod[level][idx];
+ size_t end = abs_lod[level][idx + 1];
+ return end - begin;
+}
+
void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod;
@@ -117,144 +125,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
PADDLE_ENFORCE_LT(elem_begin, NumElements(level));
PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1);
+ auto abs_lod = framework::ToAbsOffset(lod());
auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end);
lod_ = new_lod;
-}
-
-std::string LoDTensor::SerializeToString() const {
- LoDTensorProto desc;
-
- // set data_type
- if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL);
- if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16);
- if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32);
- if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64);
- // FIXME(dzh): there is no fp16 in standard c++
-
- if (this->type() == typeid(float)) // NOLINT
- desc.set_data_type(DataType::FP32);
- if (this->type() == typeid(double)) // NOLINT
- desc.set_data_type(DataType::FP64);
-
- for (int i = 0; i < dims().size(); ++i) {
- desc.add_dims(dims()[i]);
- }
-
- // set lod information
- desc.set_lod_level(this->NumLevels());
- for (size_t i = 0; i < this->NumLevels(); ++i) {
- LoDInfo* lod = desc.add_levels();
- for (size_t j = 0; j < lod_[i].size(); ++j) {
- lod->add_level(lod_[i][j]);
- }
- }
-
- desc.set_version(0);
-
- std::string desc_bytes = desc.SerializeAsString();
-
- // FIXME(dzh) : implement fix chunk size buffer.
- size_t DESC_SIZE = desc_bytes.size();
- size_t DATA_SIZE = holder_->size() - offset_;
-
- const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t);
- char* buffer =
- static_cast(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE));
-
- // format: desc_size data_size, desc_bytes, data_bytes.
- platform::CPUPlace src_place;
- platform::CPUPlace dst_place;
-
- memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t));
- memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE,
- sizeof(size_t));
- memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place,
- desc_bytes.c_str(), desc_bytes.size());
-
- PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!");
- platform::Place place = holder_->place();
- int element_width = holder_->size() / this->numel();
-
- if (platform::is_cpu_place(place)) {
- memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
- boost::get(place),
- static_cast(holder_->ptr()) + offset_ / element_width,
- DATA_SIZE);
- }
-#ifdef PADDLE_WITH_GPU
- if (platform::is_gpu_place(place)) {
- memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
- boost::get(place),
- static_cast(holder_->ptr()) + offset_ / element_width,
- DATA_SIZE);
- }
-#endif
-
- std::string ret(buffer, BUFFER_SIZE);
- memory::Free(platform::CPUPlace(), buffer);
- return ret;
+ // slice the underlying tensor
+ size_t begin = abs_lod[level][elem_begin];
+ size_t end = abs_lod[level][elem_end];
+ PADDLE_ENFORCE_LT(begin, end, "Cannot shrink, the result tensor is empty.");
+ ShareDataWith(Slice(begin, end));
}
-
-void LoDTensor::DeserializeFromString(const std::string& s,
- const platform::Place& dst_place) {
- size_t DESC_SIZE, BUFFER_SIZE;
- platform::CPUPlace src_place;
-
- memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t));
- memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t),
- sizeof(size_t));
-
- const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2;
-
- // parse LoDTensorDesc
- LoDTensorProto desc;
- desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE);
-
- std::vector dims;
- std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
- this->Resize(make_ddim(dims));
-
- // parse data type
- void* ptr = nullptr;
- if (desc.data_type() == DataType::BOOL)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::INT16)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::INT32)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::INT64)
- ptr = this->mutable_data(dst_place);
- // FIXME(dzh): there is no fp16 in standard c++
-
- if (desc.data_type() == DataType::FP32)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::FP64)
- ptr = this->mutable_data(dst_place);
-
- LoD lod;
- std::vector levels;
- for (int i = 0; i < desc.levels().size(); ++i) {
- auto current_level = desc.levels()[i].level();
- std::copy(current_level.begin(), current_level.end(),
- std::back_inserter(levels));
- lod.emplace_back(levels);
- levels.clear();
- }
-
- this->set_lod(lod);
-
- if (platform::is_cpu_place(dst_place)) {
- memory::Copy(boost::get(dst_place), ptr, src_place,
- s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
- }
-#ifdef PADDLE_WITH_GPU
- if (platform::is_gpu_place(dst_place)) {
- memory::Copy(boost::get(dst_place), ptr, src_place,
- s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
- }
-#endif
-}
-
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index f78a751c53621aa103026b5d8a251966685822bb..f4fe4cdac6019a1899fd3db8e1b6ca588be0d436 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -85,7 +85,9 @@ class LoDTensor : public Tensor {
void set_lod(const LoD& lod) { lod_ = lod; }
- LoD lod() const { return lod_; }
+ const LoD& lod() const { return lod_; }
+
+ LoD* mutable_lod() { return &lod_; }
/*
* Get the start offset and end offset of an element from LoD.
@@ -122,6 +124,12 @@ class LoDTensor : public Tensor {
*/
size_t NumElements(size_t level, size_t idx) const;
+ /*
+ * Get the number of instances in the underlying tensor in the `idx`-th
+ * element.
+ */
+ size_t NumInstancesInElement(size_t level, size_t idx) const;
+
/*
* Shrink levels[level_begin:level_end]
*/
@@ -133,29 +141,45 @@ class LoDTensor : public Tensor {
*/
void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end);
- /**
- * @brief Serialize tensor to char bytes.
- * Please check model_format.md for the format detail.
- * NOTE: GPUTensor will copy data to cpu implicitly.
- * @return return string
- */
-
- // FIXME(dzh) : Currently, this interface should only be used in
- // save/restore model and checkpoint. ParameterServer do not use shape
- // information to do the optimization, as a result, when we serialize
- // parameter/gradient to string, we should serialize the tensor
- // to string in the ps trainer instead of LoDTensor.
- std::string SerializeToString() const;
-
- /**
- * @brief Deserialize char bytes to tensor.
- * @return return string
- */
- void DeserializeFromString(const std::string& s,
- const platform::Place& dst_place);
-
private:
LoD lod_;
};
+
+/*
+ * Expand the `source` to fit the LoD of `lod`. For example, a `source`
+ * LoDTensor is
+ * - LoD: [0, 2]
+ * - tensor: [a0, a1]
+ * a `lod` is
+ * - LoD: [0 3 5]
+ * returns a new LoDTensor
+ * - [a0 a0 a0 a1 a1]
+ */
+template
+LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
+ const platform::Place& place) {
+ LoD abs_lod = ToAbsOffset(lod);
+ const auto& lod_level = lod[level];
+ size_t num_instances = source.dims()[0];
+
+ // new tensor
+ LoDTensor tensor;
+ tensor.set_lod(lod);
+ auto dims = source.dims();
+ dims[0] = lod_level.back();
+ tensor.Resize(dims);
+ tensor.mutable_data(place);
+
+ PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
+ for (size_t ins = 0; ins < num_instances; ins++) {
+ for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
+ tensor.Slice(elem, elem + 1)
+ .CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(),
+ platform::CPUDeviceContext());
+ }
+ }
+ return tensor;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index b984d620717453456fb15620b4d10c4268be8a94..aa2f6c993d41ae98e0769d470dccad3b410da53e 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
size_t level = 0;
LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.ShrinkInLevel(level, 0, 1);
- EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL);
- EXPECT_EQ(new_lod_tensor.NumElements(0), 1UL);
- EXPECT_EQ(new_lod_tensor.NumElements(1), 2UL);
- EXPECT_EQ(new_lod_tensor.NumElements(2), 5UL);
- ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data());
+ ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
+ ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
+ ASSERT_EQ(new_lod_tensor.NumElements(1), 2UL);
+ ASSERT_EQ(new_lod_tensor.NumElements(2), 5UL);
+ ASSERT_EQ(new_lod_tensor.dims()[0], 12);
+ for (int i = 0; i < 12 * 128; i++) {
+ ASSERT_EQ(new_lod_tensor.data()[i], i);
+ }
level = 1;
new_lod_tensor = lod_tensor_;
@@ -104,23 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL);
- ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data());
+ ASSERT_EQ(new_lod_tensor.dims()[0], 7);
+ for (int i = 5 * 128; i < 12 * 128; i++) {
+ ASSERT_EQ(new_lod_tensor.data()[i - 5 * 128], i);
+ }
+
+ LoDTensor t1;
+ t1.set_lod(lod_tensor_.lod());
+ t1.ShareDataWith(lod_tensor_);
+
+ LoDTensor t2;
+ t2.set_lod(lod_tensor_.lod());
+ t2.ShareDataWith(lod_tensor_);
+
+ t1.ShrinkInLevel(0, 1, 2);
+ t2.ShrinkInLevel(0, 0, 1);
+ EXPECT_NE(t1.data(), t2.data());
+ EXPECT_NE(t1.data(), lod_tensor_.data());
}
-TEST_F(LoDTensorTester, SerializeDeserialize) {
- LoDTensor new_lod_tensor = lod_tensor_;
- float* src_ptr = lod_tensor_.data();
- std::string s = lod_tensor_.SerializeToString();
- LoDTensor dst;
- dst.DeserializeFromString(s, platform::CPUPlace());
- float* dst_ptr = dst.data();
- for (int i = 0; i < kLodTensorSize; ++i) {
- EXPECT_EQ(dst_ptr[i], src_ptr[i]);
+TEST(LodExpand, test) {
+ LoD lod{{0, 2}};
+ LoDTensor tensor;
+ tensor.set_lod(lod);
+ tensor.Resize({2, 1});
+ tensor.mutable_data(platform::CPUPlace());
+ tensor.data()[0] = 0;
+ tensor.data()[1] = 1;
+
+ LoD target;
+ target.emplace_back(std::vector{0, 3, 5});
+ auto new_tensor = LodExpand(tensor, target, 0UL, platform::CPUPlace());
+ std::vector result{{0, 0, 0, 1, 1}};
+ for (size_t i = 0; i < 5; i++) {
+ ASSERT_EQ(new_tensor.data()[i], result[i]);
}
-
- ASSERT_EQ(dst.NumElements(0), 2UL);
- ASSERT_EQ(dst.NumElements(1), 3UL);
- ASSERT_EQ(dst.NumElements(2), 8UL);
}
} // namespace framework
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
index 11659be02ac340728150cf0a6438db8626c8e611..c79c4d0c721f9e568c937cb9e524e925fcdc83d0 100644
--- a/paddle/framework/lod_tensor_test.cu
+++ b/paddle/framework/lod_tensor_test.cu
@@ -47,31 +47,4 @@ TEST(LoDTensor, LoDInGPU) {
for (size_t i = 0; i < src_lod[0].size(); ++i) {
CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
}
-}
-
-TEST(LoDTensor, SerializeDeserialize) {
- paddle::framework::LoDTensor lod_tensor;
- paddle::platform::GPUPlace place(0);
-
- paddle::framework::LoD src_lod;
- src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14});
-
- lod_tensor.Resize({14, 16});
- lod_tensor.mutable_data(place);
-
- lod_tensor.set_lod(src_lod);
- CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
- CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
-
- test<<<1, 8>>>(src_lod[0].data(), src_lod[0].size());
- cudaDeviceSynchronize();
-
- std::string s = lod_tensor.SerializeToString();
- paddle::framework::LoDTensor dst;
- dst.DeserializeFromString(s, place);
- paddle::framework::LoD dst_lod = dst.lod();
-
- for (size_t i = 0; i < dst_lod[0].size(); ++i) {
- CHECK_EQ(src_lod[0].data()[i], dst_lod[0].data()[i] * 2);
- }
-}
+}
\ No newline at end of file
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index 18fabe481dac9c1b70e7c30cb83ec5ee8ac47026..133869e7b58dd2082bd6e099351609f7ed37e96a 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -14,9 +14,13 @@ limitations under the License. */
#include "paddle/framework/op_desc.h"
#include
+#include
#include
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
+#include "paddle/framework/program_desc.h"
+
+#include "glog/logging.h"
namespace paddle {
namespace framework {
@@ -24,16 +28,47 @@ namespace framework {
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs) {
- op_desc_.set_type(type);
+ desc_.set_type(type);
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}
+OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
+ : desc_(desc), need_update_(false) {
+ // restore inputs_
+ int input_size = desc_.inputs_size();
+ for (int i = 0; i < input_size; ++i) {
+ const OpDesc::Var &var = desc_.inputs(i);
+ std::vector &args = inputs_[var.parameter()];
+ int argu_size = var.arguments_size();
+ args.reserve(argu_size);
+ for (int j = 0; j < argu_size; ++j) {
+ args.push_back(var.arguments(j));
+ }
+ }
+ // restore outputs_
+ int output_size = desc_.outputs_size();
+ for (int i = 0; i < output_size; ++i) {
+ const OpDesc::Var &var = desc_.outputs(i);
+ std::vector &args = outputs_[var.parameter()];
+ int argu_size = var.arguments_size();
+ args.reserve(argu_size);
+ for (int j = 0; j < argu_size; ++j) {
+ args.push_back(var.arguments(j));
+ }
+ }
+ // restore attrs_
+ for (const OpDesc::Attr &attr : desc_.attrs()) {
+ std::string attr_name = attr.name();
+ attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
+ }
+}
+
OpDesc *OpDescBind::Proto() {
Flush();
- return &op_desc_;
+ return &desc_;
}
const std::vector &OpDescBind::Input(
@@ -167,23 +202,23 @@ struct SetAttrDescVisitor : public boost::static_visitor {
void OpDescBind::Flush() {
if (need_update_) {
- this->op_desc_.mutable_inputs()->Clear();
+ this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
- auto *input = op_desc_.add_inputs();
+ auto *input = desc_.add_inputs();
input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments());
}
- this->op_desc_.mutable_outputs()->Clear();
+ this->desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) {
- auto *output = op_desc_.add_outputs();
+ auto *output = desc_.add_outputs();
output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments());
}
- this->op_desc_.mutable_attrs()->Clear();
+ this->desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) {
- auto *attr_desc = op_desc_.add_attrs();
+ auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast(attr.second.which() - 1));
@@ -195,26 +230,26 @@ void OpDescBind::Flush() {
}
}
-using InferShapeFuncMap =
- std::unordered_map>;
-
-static InferShapeFuncMap &InferShapeFuncs() {
- static InferShapeFuncMap *g_map = nullptr;
- if (g_map == nullptr) {
- g_map = new InferShapeFuncMap();
- auto &info_map = OpInfoMap::Instance();
- // all registered kernels
- for (auto &pair : OperatorWithKernel::AllOpKernels()) {
- auto &info = info_map.Get(pair.first);
- // use empty type here to avoid runtime checks.
+static std::once_flag init_infer_shape_funcs;
+
+static void InitInferShapeFuncs() {
+ std::call_once(init_infer_shape_funcs, [] {
+ auto &map = OpInfoMap::Instance();
+ auto &info_map = *map.mutable_map();
+
+ for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
+ auto op_type = kern_pair.first;
+ auto &op_info = info_map.at(op_type);
auto op =
- static_cast(info.Creator()("", {}, {}, {}));
- g_map->insert(
- {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
+ static_cast(op_info.Creator()("", {}, {}, {}));
+ if (op_info.infer_shape_) { // infer_shape has been registered.
+ continue;
+ }
+ op_info.infer_shape_ = [op](InferShapeContext *ctx) {
+ op->InferShape(ctx);
+ };
}
- }
- return *g_map;
+ });
}
void OpDescBind::CheckAttrs() {
@@ -230,13 +265,13 @@ void OpDescBind::CheckAttrs() {
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
- auto &funcs = InferShapeFuncs();
- auto it = funcs.find(this->Type());
- if (it == funcs.end()) {
- PADDLE_THROW("Operator %s has not been registered", this->Type());
- }
+ VLOG(3) << "CompileTime infer shape on " << Type();
+ InitInferShapeFuncs();
+ auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
+ PADDLE_ENFORCE(static_cast(infer_shape),
+ "%s's infer_shape has not been registered", this->Type());
CompileTimeInferShapeContext ctx(*this, block);
- it->second(&ctx);
+ infer_shape(&ctx);
}
void OpDescBind::InferVarType(BlockDescBind *block) const {
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 313bf538ac7c947c5e77ca0ead6bb53e6a156478..9b8fe17d6eb8e95c6453a230015f59b84a76095d 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {
class BlockDescBind;
+class ProgramDescBind;
class OpDescBind {
public:
@@ -32,11 +33,13 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
+ OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
+
OpDesc *Proto();
- std::string Type() const { return op_desc_.type(); }
+ std::string Type() const { return desc_.type(); }
- void SetType(const std::string &type) { op_desc_.set_type(type); }
+ void SetType(const std::string &type) { desc_.set_type(type); }
const std::vector &Input(const std::string &name) const;
@@ -117,7 +120,7 @@ class OpDescBind {
return ret_val;
}
- OpDesc op_desc_;
+ OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 59a64d71371b546f76eabdeed7e7514e8fb0f84a..d3b1a3b5fa2cf8f6a9571e92a319f3757666657e 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -25,12 +25,19 @@
namespace paddle {
namespace framework {
+class InferShapeBase {
+ public:
+ virtual ~InferShapeBase() = default;
+ virtual void operator()(InferShapeContext*) const = 0;
+};
+
struct OpInfo {
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_;
+ InferShapeFN infer_shape_;
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
@@ -87,13 +94,13 @@ class OpInfoMap {
}
}
- const std::unordered_map& map() const {
- return map_;
- }
+ const std::unordered_map& map() const { return map_; }
+
+ std::unordered_map* mutable_map() { return &map_; }
private:
OpInfoMap() = default;
- std::unordered_map map_;
+ std::unordered_map map_;
DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index a67625fa88fd2fbe4db43241ee824519ceac7017..db154e4f76fbec444ae4347523cadd1b6d29d319 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -33,24 +33,6 @@ ExecutionContext::GetEigenDevice() const {
}
#endif
-const Tensor* GetTensorFromVar(const Variable* var) {
- if (var->IsType()) {
- return &var->Get();
- }
- PADDLE_ENFORCE(var->IsType(),
- "The Input must be LoDTensor or Tensor.");
- return &var->Get();
-}
-
-Tensor* GetTensorFromVar(Variable* var) {
- if (var->IsType()) {
- return var->GetMutable();
- }
- PADDLE_ENFORCE(var->IsType(),
- "The Input must be LoDTensor or Tensor.");
- return var->GetMutable();
-}
-
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -204,6 +186,30 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
+static const Tensor* GetTensorFromVar(const Variable* var) {
+ const Tensor* t = nullptr;
+ if (var->IsType()) {
+ t = &(var->Get());
+ } else if (var->IsType()) {
+ t = &(var->Get().value());
+ } else {
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+ }
+ return t;
+}
+
+static Tensor* GetMutableTensorFromVar(Variable* var) {
+ Tensor* t = nullptr;
+ if (var->IsType()) {
+ t = var->GetMutable();
+ } else if (var->IsType()) {
+ t = var->GetMutable()->mutable_value();
+ } else {
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+ }
+ return t;
+}
+
template <>
const Tensor* ExecutionContext::Input(const std::string& name) const {
auto* var = InputVar(name);
@@ -227,7 +233,7 @@ const std::vector ExecutionContext::MultiInput(
template <>
Tensor* ExecutionContext::Output(const std::string& name) const {
auto var = OutputVar(name);
- return var == nullptr ? nullptr : var->GetMutable();
+ return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
}
template <>
@@ -240,7 +246,7 @@ std::vector ExecutionContext::MultiOutput(
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr
- : var->GetMutable();
+ : GetMutableTensorFromVar(var);
});
return res;
}
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 0d0304ac9e13089ef533b0a47f0ec989c8fd7078..aa79f16df82ab9d81e093af60b730d9aacd09568 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
+#include "paddle/framework/selected_rows.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
@@ -60,9 +61,6 @@ inline std::string GradVarName(const std::string& var_name) {
class OperatorBase;
class ExecutionContext;
-extern const Tensor* GetTensorFromVar(const Variable* var);
-extern Tensor* GetTensorFromVar(Variable* var);
-
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -414,7 +412,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
private:
DDim GetDim(const std::string& name) const override {
- return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
+ auto var = block_.FindVarRecursive(name);
+ PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
+ return framework::make_ddim(var->Shape());
}
void SetDim(const std::string& name, const DDim& dim) override {
@@ -511,28 +511,26 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
private:
- template
- Tensor* GetTensor(const std::string& name) const {
- Tensor* t = nullptr;
- auto* var = scope_.FindVar(name);
- if (!var->IsType() && !var->IsType()) {
- if (Allocate) {
- t = var->GetMutable();
- } else {
- PADDLE_THROW("Variable(%s) should be tensor", name);
- }
+ DDim GetDim(const std::string& name) const override {
+ Variable* var = scope_.FindVar(name);
+ if (var->IsType()) {
+ return var->Get().dims();
+ } else if (var->IsType()) {
+ return var->Get().GetCompleteDims();
} else {
- t = GetTensorFromVar(scope_.FindVar(name));
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
- return t;
- }
-
- DDim GetDim(const std::string& name) const override {
- return GetTensor(name)->dims();
}
void SetDim(const std::string& name, const DDim& dim) override {
- GetTensor(name)->Resize(dim);
+ Variable* var = scope_.FindVar(name);
+ if (var->IsType()) {
+ var->GetMutable()->Resize(dim);
+ } else if (var->IsType()) {
+ var->GetMutable()->set_height(dim[0]);
+ } else {
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+ }
}
const OperatorBase& op_;
@@ -638,7 +636,9 @@ class OperatorWithKernel : public OperatorBase {
});
}
- virtual void InferShape(InferShapeContext* ctx) const = 0;
+ virtual void InferShape(InferShapeContext* ctx) const {
+ OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
+ }
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
@@ -655,11 +655,14 @@ class OperatorWithKernel : public OperatorBase {
t = &var->Get();
} else if (var->IsType()) {
t = &var->Get();
+ } else if (var->IsType()) {
+ t = &(var->Get().value());
}
if (t != nullptr) {
int tmp = static_cast(ToDataType(t->type()));
+ VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
- "DataType of Paddle Op must be same.");
+ "DataType of Paddle Op %s must be same.", Type());
data_type = tmp;
}
}
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index c358f1a2b6ee3174b8c336ba1d212be7c5aa15c6..3c07621293389fc7803b0295d9d30b2c12d6e327 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -237,12 +237,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope;
- scope.Var("x0")->GetMutable();
- scope.Var("x1")->GetMutable();
- scope.Var("x2")->GetMutable();
- scope.Var("k0")->GetMutable();
- scope.Var("y0")->GetMutable();
- scope.Var("y1")->GetMutable();
+ scope.Var("x0")->GetMutable();
+ scope.Var("x1")->GetMutable();
+ scope.Var("x2")->GetMutable();
+ scope.Var("k0")->GetMutable();
+ scope.Var("y0")->GetMutable();
+ scope.Var("y1")->GetMutable();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context);
diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc
index 8e99bba81117c9cc50227122527d6ab9a421c251..82f16a7c8b9de2b46dcae4288d999bc5c644aede 100644
--- a/paddle/framework/program_desc.cc
+++ b/paddle/framework/program_desc.cc
@@ -19,9 +19,9 @@ namespace paddle {
namespace framework {
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
- auto *b = prog_.add_blocks();
+ auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID());
- b->set_idx(prog_.blocks_size() - 1);
+ b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
@@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
- return &prog_;
+ return &desc_;
}
ProgramDescBind::ProgramDescBind() {
- auto *block = prog_.mutable_blocks()->Add();
+ auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block));
}
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
- prog_ = o.prog_;
+ desc_ = o.desc_;
- for (int i = 0; i < prog_.blocks_size(); ++i) {
- auto *block = prog_.mutable_blocks(i);
+ for (int i = 0; i < desc_.blocks_size(); ++i) {
+ auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
}
}
+
+ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
+ PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
+ "Fail to parse program_desc from binary string.");
+ for (auto &block_desc : *desc_.mutable_blocks()) {
+ blocks_.emplace_back(new BlockDescBind(this, &block_desc));
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h
index dc4cd7cc735b5e4e3466d9b82dc5eb8647c80ef9..b6e76515a5af0f1ff663442faebc50e1c5cc2520 100644
--- a/paddle/framework/program_desc.h
+++ b/paddle/framework/program_desc.h
@@ -31,6 +31,8 @@ class ProgramDescBind {
ProgramDescBind(const ProgramDescBind &o);
+ explicit ProgramDescBind(const std::string &binary_str);
+
BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
@@ -40,7 +42,7 @@ class ProgramDescBind {
ProgramDesc *Proto();
private:
- ProgramDesc prog_;
+ ProgramDesc desc_;
std::vector> blocks_;
};
diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc
index c9709a2d3f1d9e0be2bda1e8e9e7835ca49141b1..d28c2a0bff932f5aa37c69231495895dacb07bb3 100644
--- a/paddle/framework/program_desc_test.cc
+++ b/paddle/framework/program_desc_test.cc
@@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) {
};
ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
- ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
+ ASSERT_EQ(3UL, global_block_copy->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
@@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) {
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
+
+TEST(ProgramDescBind, serialize_and_deserialize) {
+ ProgramDescBind program_origin;
+ auto* global_block = program_origin.Block(0);
+ auto* x = global_block->Var("X");
+ x->SetType(VarDesc_VarType_LOD_TENSOR);
+ x->SetLoDLevel(0);
+ x->SetDataType(FP32);
+ x->SetShape({1000, 784});
+
+ auto* y = global_block->Var("Y");
+ y->SetType(VarDesc_VarType_LOD_TENSOR);
+ y->SetLoDLevel(0);
+ y->SetDataType(FP32);
+ y->SetShape({784, 100});
+
+ auto* op = global_block->AppendOp();
+ op->SetType("mul");
+ op->SetInput("X", {x->Name()});
+ op->SetInput("Y", {y->Name()});
+
+ auto* out = global_block->Var("Out");
+ out->SetType(VarDesc_VarType_LOD_TENSOR);
+ op->SetOutput("Y", {out->Name()});
+
+ std::string binary_str;
+ program_origin.Proto()->SerializeToString(&binary_str);
+
+ ProgramDescBind program_restored(binary_str);
+ auto* global_block_restored = program_restored.Block(0);
+ ASSERT_NE(global_block, global_block_restored);
+
+ auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
+ ASSERT_TRUE(global_block_restored->HasVar(name));
+ auto* restored = global_block_restored->Var(name);
+ ASSERT_NE(restored, var_before);
+ ASSERT_EQ(restored->Name(), var_before->Name());
+ ASSERT_EQ(restored->GetType(), var_before->GetType());
+ ASSERT_EQ(restored->Shape(), var_before->Shape());
+ ASSERT_EQ(restored->Proto()->SerializeAsString(),
+ var_before->Proto()->SerializeAsString());
+ };
+
+ ASSERT_EQ(global_block->LocalVarNames(),
+ global_block_restored->LocalVarNames());
+ ASSERT_EQ(3UL, global_block_restored->LocalVarNames().size());
+ assert_same_var("X", x);
+ assert_same_var("Y", y);
+ assert_same_var("Out", out);
+
+ for (size_t i = 0; i < global_block->OpSize(); ++i) {
+ auto op_origin = global_block->Op(i);
+ auto op_restored = global_block->Op(i);
+
+ ASSERT_EQ(op_origin->Type(), op_restored->Type());
+ ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs());
+ ASSERT_EQ(op_origin->Outputs(), op_restored->Outputs());
+
+ ASSERT_EQ(op_restored->Proto()->SerializeAsString(),
+ op_origin->Proto()->SerializeAsString());
+ }
+}
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/saver.proto b/paddle/framework/saver.proto
deleted file mode 100644
index 90a191a6a79250761489b68916b1fa09116830f2..0000000000000000000000000000000000000000
--- a/paddle/framework/saver.proto
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-package paddle.framework;
-
-import "framework.proto";
-
-/**
- * This file contains necessary information for model, checkpoint.
- * etc.
- */
-
-message LoDInfo { repeated int64 level = 1; }
-
-/**
- * Save the LoDTensorDesc information through LoDTensorProto, its data memory
- * is copyed to c buffer immediately. See model_format.md for details.
- */
-
-message LoDTensorProto {
- optional DataType data_type = 1;
- repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
- repeated LoDInfo levels = 3;
- optional int32 lod_level = 4 [ default = 0 ];
- optional int32 version = 5;
-}
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index cd9078137132669c7265ce3972f2c6df996fa366..0332b91323e3a4b4b80e02302ad3dcafe0986cde 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -23,7 +23,10 @@ class SelectedRows {
value_.reset(new Tensor());
}
- SelectedRows() { value_.reset(new Tensor()); }
+ SelectedRows() {
+ height_ = 0;
+ value_.reset(new Tensor());
+ }
platform::Place place() const { return value_->place(); }
@@ -37,6 +40,8 @@ class SelectedRows {
const Vector& rows() const { return rows_; }
+ Vector* mutable_rows() { return &rows_; }
+
void set_rows(const Vector& rows) { rows_ = rows; }
DDim GetCompleteDims() const {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index e31472327dbca45dc12ea2c9e494beddd36860dc..9d2dc6a32bb2d4f6368fd9c7264c55fb9588819c 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -132,6 +132,8 @@ class Tensor {
std::type_index type() const { return holder_->type(); }
+ size_t memory_size() const;
+
private:
inline void check_memory_size() const;
diff --git a/paddle/framework/tensor_array.cc b/paddle/framework/tensor_array.cc
index 4c82c3638351c41df26503e2a26b5a4bb5822a67..0947e33548130a923e998f8bad68db00097af909 100644
--- a/paddle/framework/tensor_array.cc
+++ b/paddle/framework/tensor_array.cc
@@ -20,6 +20,8 @@
#include
#include
+#include "paddle/framework/eigen.h"
+
namespace paddle {
namespace framework {
@@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) {
values_.resize(index + 1);
}
+ values_[index].set_lod(value.lod());
values_[index].Resize(value.dims());
- values_[index].mutable_data(platform::CPUPlace());
- values_[index].CopyFrom(value, platform::CPUPlace(),
- platform::CPUDeviceContext());
+ values_[index].mutable_data(value.place());
+ values_[index].CopyFrom(value, value.place(), platform::CPUDeviceContext());
}
void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
@@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
values_.resize(index + 1);
}
+ values_[index].set_lod(value.lod());
values_[index].ShareDataWith(value);
}
@@ -144,6 +147,155 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
return unpacker.meta;
}
+LoDTensor TensorArray::LodPack(size_t level) const {
+ PADDLE_ENFORCE_GT(size(), 0UL, "no time step exists");
+ // the levels should be no less than 2
+ LoDTensor merged;
+ const LoDTensor *pre, *cur;
+ pre = &Read(0);
+
+ for (size_t step = 1; step < size(); step++) {
+ cur = &Read(step);
+ PADDLE_ENFORCE_GT(cur->NumLevels(), 0);
+ PADDLE_ENFORCE_GT(pre->NumLevels(), 0);
+ PADDLE_ENFORCE_EQ(pre->NumLevels(), cur->NumLevels());
+ PADDLE_ENFORCE_EQ(pre->NumElements(level), cur->NumElements(level));
+
+ merged = LodPackTwo(*pre, *cur, level);
+ pre = &merged;
+ }
+ return merged;
+}
+
+/*
+ * NOTE currently, only the lowest level supports packing.
+ * The lowest LoD will be changed, while the relative offsets in levels above
+ * stay unchanged.
+ *
+ * previous step : [0] [1] [3]
+ * current step: [0 1 2] [2 3] []
+ * packed to
+ * [0 0] [0 1] [0 2] [1 2] [1 3] [3]
+ */
+LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur,
+ size_t level) const {
+ PADDLE_ENFORCE_EQ(pre.NumLevels(), cur.NumLevels());
+ PADDLE_ENFORCE_EQ(pre.NumLevels(), level + 1,
+ "Only the lowest LoD level supports pack temporarily.");
+ // calculate the result tensor's shape first
+ size_t num_instances = 0;
+ for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
+ size_t prefix_size = pre.NumElements(level, elem);
+ size_t num_candidates = cur.NumElements(level, elem);
+ if (num_candidates > 0) {
+ num_instances += num_candidates * (prefix_size + 1);
+ } else {
+ num_instances += prefix_size;
+ }
+ }
+
+ auto res_dims = pre.dims();
+ res_dims[0] = num_instances;
+ LoDTensor result;
+ result.Resize(res_dims);
+ result.mutable_data(cur.place());
+
+ Vector last_lod_level;
+ // copy data
+ size_t index = 0;
+ last_lod_level.push_back(index);
+ for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
+ size_t prefix_size = pre.NumElements(level, elem);
+ size_t num_candidates = cur.NumElements(level, elem);
+
+ // slice the prefix Tensor
+ LoDTensor prefix = pre;
+ prefix.ShrinkInLevel(level, elem, elem + 1);
+ LoDTensor candidate = cur;
+ if (num_candidates > 0) {
+ candidate.ShrinkInLevel(level, elem, elem + 1);
+ } else { // just push prefix
+ result.Slice(index, index + prefix_size)
+ .CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
+ index += prefix_size;
+ last_lod_level.push_back(index);
+ }
+ for (size_t candi = 0; candi < num_candidates; candi++) {
+ // TODO(superjom) support GPU
+ result.Slice(index, index + prefix_size)
+ .CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
+ index += prefix_size;
+ // copy candidate record
+ result.Slice(index, index + 1)
+ .CopyFrom(candidate.Slice(candi, candi + 1), result.place(),
+ platform::CPUDeviceContext());
+ index++;
+ last_lod_level.push_back(index);
+ }
+ }
+
+ // update lod
+ auto lod = cur.lod();
+ lod.back() = last_lod_level;
+ result.set_lod(lod);
+ return result;
+}
+
+/*
+ * source [0 1 2] [3 4] [5 6 7] will be transformd to a list of LoDTensors such
+ * as
+ * [0 3 5] [1 4 6] [2 7] with 1-level LoDs:
+ * - [0 1 2 3]
+ * - [0 1 2 3]
+ * - [0 1 1 2], the [1,1) here means the second sequence is empty
+ *
+ * NOTE Unpack a LoDTensor in this approach may result in a big LoD.
+ */
+void TensorArray::LodUnpack(const LoDTensor& source, size_t level) {
+ PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1,
+ "only the lowest LoD level supports unpack.");
+ const size_t non_empty_instances = source.dims()[0];
+ size_t index = 0;
+ Vector lowest_lod_level;
+ lowest_lod_level.push_back(index);
+
+ for (size_t step = 0; step < non_empty_instances; step++) {
+ size_t num_instances = 0;
+ for (size_t id = 0; id < source.NumElements(level); id++) {
+ auto instance = source;
+ instance.ShrinkInLevel(level, id, id + 1);
+ if (static_cast(instance.dims()[0]) > step) {
+ num_instances++;
+ index++;
+ }
+ lowest_lod_level.push_back(index);
+ }
+
+ // create tensor for this time step
+ LoDTensor tensor;
+ auto dims = source.dims();
+ dims[0] = num_instances;
+ // set lod
+ auto lod = source.lod();
+ lod.back() = lowest_lod_level;
+ tensor.set_lod(lod);
+
+ index = 0;
+ for (size_t id = 0; id < source.NumElements(level); id++) {
+ auto instance = source;
+ instance.ShrinkInLevel(level, id, id + 1);
+ if (static_cast(instance.dims()[0]) > step) {
+ // copy this instance
+ tensor.Slice(index, index + 1)
+ .CopyFrom(instance.Slice(step, step + 1), tensor.place(),
+ platform::CPUDeviceContext());
+ index++;
+ }
+ }
+ Write(step, tensor);
+ }
+}
+
LoDTensor TensorArray::Stack() const {
LoDTensor result;
if (size() == 0) return result;
diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h
index 046ecb5221b7ed9d88e5017348ee8fcde23c7677..78fad8cab7e27a7f07ca542c2a083460ee9e2b79 100644
--- a/paddle/framework/tensor_array.h
+++ b/paddle/framework/tensor_array.h
@@ -86,6 +86,16 @@ class TensorArray {
*/
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
+ /*
+ * Pack an array of LoDTensors to a LoDTensor.
+ */
+ LoDTensor LodPack(size_t level) const;
+
+ /*
+ * Unpack a LoDTensor to an array of LoDTensors.
+ */
+ void LodUnpack(const LoDTensor &source, size_t level);
+
/*
* Pack the values into a tensor with rank one higher than each tensor in
* values.
@@ -111,6 +121,9 @@ class TensorArray {
protected:
void Unstack(const LoDTensor &source, bool data_shared) const;
+ LoDTensor LodPackTwo(const LoDTensor &pre, const LoDTensor &cur,
+ size_t level) const;
+
private:
mutable std::vector values_;
}; // class TensorArray
diff --git a/paddle/framework/tensor_array_test.cc b/paddle/framework/tensor_array_test.cc
index 9470ac5e6ed714d5ba63f3743e683af7f8edd4b0..83b52b442daf9b2f1fc40f23e458fcb67c5040e8 100644
--- a/paddle/framework/tensor_array_test.cc
+++ b/paddle/framework/tensor_array_test.cc
@@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) {
ASSERT_EQ(ta.size(), static_cast(batch_size));
}
+TEST(TensorArray, LodPack) {
+ // three time steps, each step stores a LoDTensors
+ // - [0] [1]
+ // - [2 3], [4 5]
+ // - [6 7] [] [8], [9, 10]
+ // try to get a LoDTensor with content:
+ // - [0 2 6]
+ // - [0 2 7]
+ // - [0 3]
+ // - [1 4 8]
+ // - [1 5 9]
+ // - [1 5 10]
+ std::array tensors;
+ tensors[0].Resize(make_ddim({2, 1}));
+ tensors[1].Resize(make_ddim({4, 1}));
+ tensors[2].Resize(make_ddim({5, 1}));
+ int index = 0;
+ for (auto& t : tensors) {
+ t.mutable_data(platform::CPUPlace());
+ for (int i = 0; i < t.dims()[0]; i++) {
+ t.data()[i] = index;
+ index++;
+ }
+ }
+
+ std::array lods;
+ std::vector> levels{
+ {0, 1, 2}, {0, 2, 4}, {0, 2, 2, 3, 5}};
+ for (int i = 0; i < 3; i++) {
+ lods[i].emplace_back(levels[i].begin(), levels[i].end());
+ }
+
+ TensorArray ta;
+ for (int i = 0; i < 3; i++) {
+ tensors[i].set_lod(lods[i]);
+ ta.Write(i, tensors[i]);
+ }
+
+ auto merged = ta.LodPack(0);
+
+ std::vector target_tensor_data{{0, 2, 6, // 0
+ 0, 2, 7, // 1
+ 0, 3, // 2
+ 1, 4, 8, // 3
+ 1, 5, 9, // 5
+ 1, 5, 10}};
+ EXPECT_EQ(merged.dims()[0], (int)target_tensor_data.size());
+ for (size_t i = 0; i < target_tensor_data.size(); i++) {
+ EXPECT_EQ(target_tensor_data[i], merged.data()[i]);
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index f6e801bbb4a056b5590da95a4b140cb90638f322..29ac683f48fcde4dd3b5ad7f04b5d1d7434706ba 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -62,12 +62,16 @@ inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
- holder_->size(), numel() * SizeOfType(type()) + offset_,
+ holder_->size(), memory_size() + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
}
+inline size_t Tensor::memory_size() const {
+ return holder_ == nullptr ? 0UL : numel() * SizeOfType(type());
+}
+
template
inline const T* Tensor::data() const {
check_memory_size();
diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h
index 00da7289394cf18e013220a4bedde2c182f6a4a4..c38c4a8ae9a46c8bda913e7643e812592de68e6e 100644
--- a/paddle/framework/type_defs.h
+++ b/paddle/framework/type_defs.h
@@ -28,6 +28,8 @@ class OperatorBase;
class OpDescBind;
class BlockDescBind;
class BlockDesc;
+class InferShapeContext;
+
using VariableNameMap = std::map>;
// The order should be as same as framework.proto
@@ -49,5 +51,7 @@ using GradOpMakerFN = std::function>(
using InferVarTypeFN = std::function;
+using InferShapeFN = std::function;
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h
index 929de1f836fa906966ff125c70380d85d062afdf..70daa20e8d99abc5759655adf538a8c197e9ec6a 100644
--- a/paddle/framework/var_desc.h
+++ b/paddle/framework/var_desc.h
@@ -59,6 +59,8 @@ class VarDescBind {
desc_.set_type(VarDesc::LOD_TENSOR);
}
+ explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
+
VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); }
diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h
index a80f0e66b5a59bf95efc200d159ad5dd9cf4111a..cde5ec2413ad01a0396e19fa617688af0eafbc75 100644
--- a/paddle/framework/variable.h
+++ b/paddle/framework/variable.h
@@ -46,6 +46,8 @@ class Variable {
std::type_index(typeid(T)) == std::type_index(holder_->Type());
}
+ void Clear() { holder_.reset(); }
+
private:
struct Placeholder {
virtual ~Placeholder() {}
diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f577616230be65e9581cf8f3ed5f63a77c7c3e21
--- /dev/null
+++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
@@ -0,0 +1,318 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "MKLDNNBatchNormLayer.h"
+
+using namespace mkldnn; // NOLINT
+typedef memory::format format;
+
+namespace paddle {
+
+REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
+
+const real MKLDNNBatchNormLayer::EPS = 1E-5;
+
+bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
+ const ParameterMap& parameterMap) {
+ if (!MKLDNNLayer::init(layerMap, parameterMap)) {
+ return false;
+ }
+
+ // first one is input layer
+ // the other two are created in config_parser.py saving moving mean and var
+ CHECK_EQ(inputLayers_.size(), 3U);
+ CHECK_EQ(inputLayers_.size(), parameters_.size());
+ CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size()));
+
+ const ImageConfig& conf = config_.inputs(0).image_conf();
+ ic_ = conf.channels();
+ ih_ = inputLayers_[0]->getOutput().getFrameHeight();
+ iw_ = inputLayers_[0]->getOutput().getFrameWidth();
+ if (iw_ == 0 && ih_ == 0) {
+ iw_ = conf.img_size();
+ ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
+ }
+ oc_ = ic_;
+ oh_ = ih_;
+ ow_ = iw_;
+ if (config_.has_use_global_stats()) {
+ useGlobalStats_ = config_.use_global_stats();
+ }
+ movingAvgFraction_ = config_.moving_average_fraction();
+ VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
+ << " --- global stats";
+ VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
+
+ initWeight();
+ movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0));
+ movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0));
+ return true;
+}
+
+void MKLDNNBatchNormLayer::initWeight() {
+ weight_.reset(new Weight(1, oc_, parameters_[0]));
+ if (biasParameter_.get() != NULL) {
+ biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_));
+ }
+ CHECK_EQ(weight_ != nullptr, biases_ != nullptr)
+ << "only support have both weight and bias, or neither";
+ if (weight_ && weight_->getW()) {
+ CHECK(biases_ && biases_->getW());
+ valueScaleShift_ = Matrix::create(2, oc_, false, false);
+ valueScaleShift_->zeroMem();
+ VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0));
+ VectorPtr shift(
+ new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_));
+ const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE);
+ const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE);
+ scale->copyFrom(*wgt);
+ shift->copyFrom(*bias);
+ wgt->setData(valueScaleShift_->getData());
+ bias->setData(valueScaleShift_->getData() + oc_);
+ }
+ if (weight_ && weight_->getWGrad()) {
+ CHECK(biases_ && biases_->getWGrad());
+ gradScaleShift_ = Matrix::create(2, oc_, false, false);
+ gradScaleShift_->zeroMem();
+ const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT);
+ const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT);
+ wgt->setData(gradScaleShift_->getData());
+ bias->setData(gradScaleShift_->getData() + oc_);
+ }
+}
+
+void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
+ if (hasInitedWgt_) {
+ return;
+ }
+ // prepare mean and var if necessary
+ if (useGlobalStats_) {
+ CHECK(mean_);
+ CHECK(var_);
+ mean_->copyFrom(*(movingMean_->getW()));
+ var_->copyFrom(*(movingVar_->getW()));
+ }
+ hasInitedWgt_ = true;
+}
+
+void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
+ // calculating and saving moving mean and variance
+ CHECK_EQ(useGlobalStats_, false);
+ movingMean_->getW()->add(
+ *mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
+ // here var is v^2
+ movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
+}
+
+void MKLDNNBatchNormLayer::reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+ reshapeInput(bs, ih, iw);
+ oh = ih;
+ ow = ow;
+ // ic_ and oc can not be changed
+ CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
+ << "Input channel can not be changed";
+ reshapeOutput(oh, ow);
+ resizeOutput(bs, oc * oh * ow);
+ printSizeInfo();
+}
+
+void MKLDNNBatchNormLayer::resetFwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) {
+ // In training phase, it will always calculate mean and var,
+ // so useGlobalStats must be false.
+ // In scoring phase, it depends on useGlobalStats choice.
+ if (passType_ != PASS_TEST && useGlobalStats_ == true) {
+ LOG(WARNING) << "use_global_stats is invalid setting in training phase";
+ useGlobalStats_ = false;
+ }
+
+ resetFwdBuffers(in, wgt, out);
+
+ resetFwdPD(fwdPD_, in, wgt, out);
+
+ resetFwdPipeline(pipeline, fwdPD_, in, wgt, out);
+}
+
+void MKLDNNBatchNormLayer::resetBwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) {
+ std::shared_ptr pd;
+
+ resetBwdBuffers(in, wgt, out);
+
+ resetBwdPD(pd, in, wgt, out);
+
+ resetBwdPipeline(pipeline, pd, in, wgt, out);
+}
+
+void MKLDNNBatchNormLayer::forward(PassType passType) {
+ MKLDNNLayer::forward(passType);
+
+ // calculate and save moving mean and variance
+ if (passType_ != PASS_TEST) {
+ calMovingMeanAndVar();
+ }
+}
+
+void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) {
+ weight_->getParameterPtr()->incUpdate(callback);
+ if (biases_ && biases_->getWGrad()) {
+ biases_->getParameterPtr()->incUpdate(callback);
+ }
+}
+
+void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ resetInValue(in);
+
+ memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_};
+ CHECK(in);
+ auto outPD =
+ MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_);
+ resetOutValue(out, outPD);
+
+ if (valueScaleShift_) {
+ auto pd = MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_);
+ resetWithMatrix(wgt, valueScaleShift_, pd);
+ }
+ if (passType_ != PASS_TEST || useGlobalStats_) {
+ auto pd = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
+ mean_ = MKLDNNMatrix::create(pd);
+ var_ = MKLDNNMatrix::create(pd);
+ }
+}
+
+void MKLDNNBatchNormLayer::resetFwdPD(
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr in,
+ MKLDNNMatrixPtr wgt,
+ MKLDNNMatrixPtr out) {
+ flags_ = 0u;
+ prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
+ : prop_kind::forward_training;
+ if (useGlobalStats_) {
+ flags_ = (flags_ | batch_normalization_flag::use_global_stats);
+ }
+ if (wgt) {
+ flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
+ }
+ auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
+ pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
+ // TODO(TJ): use check macro
+ CHECK(out);
+ CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc());
+ if (wgt) {
+ CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc());
+ }
+ if (passType_ != PASS_TEST || useGlobalStats_) {
+ CHECK(mean_);
+ CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
+ CHECK(var_);
+ CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
+ }
+}
+
+void MKLDNNBatchNormLayer::resetFwdPipeline(
+ std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ if (passType_ == PASS_TEST) {
+ if (useGlobalStats_) {
+ fwd_.reset(wgt != nullptr ? new bn_fwd(*pd,
+ *in,
+ (const primitive::at)(*mean_),
+ (const primitive::at)(*var_),
+ *wgt,
+ *out)
+ : new bn_fwd(*pd,
+ *in,
+ (const primitive::at)(*mean_),
+ (const primitive::at)(*var_),
+ *out));
+ } else {
+ fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out)
+ : new bn_fwd(*pd, *in, *out));
+ }
+ } else {
+ CHECK_EQ(useGlobalStats_, false)
+ << "useGlobalStats should be false in training";
+ fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_)
+ : new bn_fwd(*pd, *in, *out, *mean_, *var_));
+ }
+ pipeline.push_back(*fwd_);
+}
+
+void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ CHECK(inVal_ && outVal_);
+ resetOutGrad(out, outVal_->getPrimitiveDesc());
+ resetInGrad(in, inVal_->getPrimitiveDesc());
+ if (gradScaleShift_) {
+ CHECK(wgtVal_);
+ resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
+ }
+}
+
+void MKLDNNBatchNormLayer::resetBwdPD(
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ pd = nullptr;
+ if (in == nullptr) {
+ return;
+ }
+ CHECK(out);
+ CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc());
+ auto md = in->getMemoryDesc();
+ auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
+ pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
+ // TODO(TJ): use check macro
+ CHECK(wgt);
+ CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc());
+ CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
+ CHECK(mean_);
+ CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
+ CHECK(var_);
+ CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
+}
+
+void MKLDNNBatchNormLayer::resetBwdPipeline(
+ std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ if (pd == nullptr) {
+ return;
+ }
+ CHECK(inVal_);
+ bwdData_.reset(
+ wgt && wgtVal_
+ ? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt)
+ : new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in));
+ pipeline.push_back(*bwdData_);
+}
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.h b/paddle/gserver/layers/MKLDNNBatchNormLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..456c0424ecb8dde17f98a900c5d77268cc672e34
--- /dev/null
+++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.h
@@ -0,0 +1,138 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "MKLDNNLayer.h"
+#include "mkldnn.hpp"
+
+namespace paddle {
+typedef mkldnn::batch_normalization_forward bn_fwd;
+typedef mkldnn::batch_normalization_backward bn_bwd;
+
+/**
+ * @brief A subclass of MKLDNNLayer BatchNorm layer.
+ *
+ * The config file api is mkldnn_batch_norm
+ */
+class MKLDNNBatchNormLayer : public MKLDNNLayer {
+protected:
+ // save forward primitive_desc, which can be used backward
+ std::shared_ptr fwdPD_;
+
+ // Epsilon value used in the batch normalization formula.
+ static const real EPS;
+ // weight and bias in paddle
+ std::unique_ptr weight_;
+ std::unique_ptr biases_;
+ // mkldnn use a large buffer store both scale and shift
+ // which are weight and bias in paddle corresponding.
+ MatrixPtr valueScaleShift_;
+ MatrixPtr gradScaleShift_;
+ // Moving average of mean.
+ std::unique_ptr movingMean_;
+ // Moving average of variance.
+ std::unique_ptr movingVar_;
+
+ // if useGlobalStats_ is true, will use the loaded mean and variance.
+ // otherwise, calculate mean and variance in every mini-batch.
+ bool useGlobalStats_;
+ // used in MKLDNN primitive desc
+ unsigned flags_;
+ // use to compute moving mean and variance.
+ real movingAvgFraction_;
+ // whether the weight has been init
+ bool hasInitedWgt_;
+
+ // local mean and variance
+ // when useGlobalStats_ they are loaded from moving mean and variance
+ // when do not useGlobalStats_ they are calculated from this mini-batch
+ MKLDNNMatrixPtr mean_;
+ MKLDNNMatrixPtr var_;
+
+public:
+ explicit MKLDNNBatchNormLayer(const LayerConfig& config)
+ : MKLDNNLayer(config), useGlobalStats_(true), hasInitedWgt_(false) {}
+
+ ~MKLDNNBatchNormLayer() {}
+
+ bool init(const LayerMap& layerMap,
+ const ParameterMap& parameterMap) override;
+
+ void forward(PassType passType) override;
+
+ void reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+
+ void resetFwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) override;
+
+ void resetBwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) override;
+
+ void updateWeights(const UpdateCallback& callback) override;
+
+ void convertWeightsFromPaddle() override;
+
+protected:
+ void initWeight();
+ /**
+ * cal moving mean and variance.
+ * moving = moving * AvgFraction + local * (1 - AvgFraction)
+ */
+ void calMovingMeanAndVar();
+ /**
+ * Forward functions: reset buffers(input, weight, output),
+ * reset primitive descriptor,
+ * reset pipeline.
+ */
+ void resetFwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+ void resetFwdPD(std::shared_ptr& pd,
+ MKLDNNMatrixPtr in,
+ MKLDNNMatrixPtr wgt,
+ MKLDNNMatrixPtr out);
+ void resetFwdPipeline(std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+
+ /**
+ * Backward functions: reset buffers(input, weight, output),
+ * reset primitive descriptor,
+ * reset pipeline.
+ */
+ void resetBwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+ void resetBwdPD(std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+ void resetBwdPipeline(std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+};
+
+} // namespace paddle
diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp
index 0a19fe23336ea943cb8a572dc40f8c0fbbd7236a..73b7e8857f35d194e71b2b5b341f89b77fd1f8b0 100644
--- a/paddle/gserver/tests/MKLDNNTester.cpp
+++ b/paddle/gserver/tests/MKLDNNTester.cpp
@@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() {
// init randome parameters of ref, and copy to mkldnn
void MKLDNNTester::randomWgtDatas() {
EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size());
+ const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < parameters_[REF].size(); ++i) {
const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE);
const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE);
parameters_[REF][i]->randomize();
+ if (isBN && i == 2) {
+ // this param is moving average in batch norm, which must larger than 0
+ real offset = fabs(refValue->getMin()) + 1.0;
+ refValue->add(offset);
+ }
dnnValue->copyFrom(*refValue);
VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName();
@@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() {
void MKLDNNTester::checkBackwardData() {
VLOG(MKLDNN_TESTS) << "Check Backward Data";
- // TODO(TJ): uncomment me when batch norm ready
- // const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
+ const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad();
const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad();
@@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() {
double delta = compareMatrix(dnnDiff, refDiff);
EXPECT_LE(fabs(delta), eps_);
- // TODO(TJ): uncomment me when batch norm ready
- // if (isBN) {
- // // the other two inputs in batch norm are for moving mean and var
- // break;
- // }
+ if (isBN) {
+ // the other two inputs in batch norm are for moving mean and var
+ // do not have grad to compare
+ break;
+ }
}
}
@@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) {
void MKLDNNTester::runOnce() {
// test forward
randomBotDatas();
- dnnLayer_->forward(PASS_TRAIN);
- refLayer_->forward(PASS_TRAIN);
+ dnnLayer_->forward(passType_);
+ refLayer_->forward(passType_);
checkForward();
+ if (passType_ == PASS_TEST) {
+ return;
+ }
+
// test backward
// simple updater
UpdateCallback updateCallback = [](Parameter* para) {
@@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
size_t batchSize,
size_t inputImgH,
size_t inputImgW,
+ PassType passType,
bool printDetails,
size_t iter,
float epsilon) {
@@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
ih_ = inputImgH;
iw_ = inputImgW;
+ passType_ = passType;
log_ = printDetails;
iter_ = iter;
eps_ = epsilon;
diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h
index c385d1c72717d120211f167b5c5eb9a557da3714..19d8848f74f2ee4a809e42164a0eb180abd2a4e1 100644
--- a/paddle/gserver/tests/MKLDNNTester.h
+++ b/paddle/gserver/tests/MKLDNNTester.h
@@ -62,12 +62,15 @@ protected:
float eps_;
/// input image size, default 1
size_t ih_, iw_;
+ /// passType, PASS_TRAIN, PASS_TEST or PASS_GC (Gradient Check pass)
+ PassType passType_;
public:
explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) {
iter_ = iter;
eps_ = epsilon;
log_ = false;
+ passType_ = PASS_TRAIN;
}
~MKLDNNTester() {}
@@ -78,6 +81,7 @@ public:
size_t batchSize,
size_t inputImgH = 1,
size_t inputImgW = 1,
+ PassType passType = PASS_TRAIN,
bool printDetails = false,
size_t iter = 3,
float epsilon = 1e-4);
diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp
index 6cb4ca5e08eab5b979e404c9e09dcfec11086c22..85d4f437c2664135a7975c6ed3270d8f1ddbeaf4 100644
--- a/paddle/gserver/tests/test_MKLDNN.cpp
+++ b/paddle/gserver/tests/test_MKLDNN.cpp
@@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) {
testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2});
}
+struct testBatchNormDesc {
+ int bs;
+ int ic;
+ int ih, iw;
+};
+
+static void getMKLDNNBatchNormConfig(TestConfig& cfg,
+ const testBatchNormDesc& pm) {
+ cfg.layerConfig.set_size(pm.ic * pm.ih * pm.iw);
+ cfg.layerConfig.set_type("mkldnn_batch_norm");
+ cfg.biasSize = pm.ic;
+ cfg.inputDefs.push_back(
+ {INPUT_DATA,
+ "layer_0",
+ /* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw),
+ /* size of weight= */ size_t(pm.ic)});
+ cfg.inputDefs.push_back(
+ {INPUT_DATA, "layer_1_moving_mean", 1, size_t(pm.ic)});
+ cfg.inputDefs.back().isStatic = true;
+ cfg.inputDefs.push_back({INPUT_DATA, "layer_2_moving_var", 1, size_t(pm.ic)});
+ cfg.inputDefs.back().isStatic = true;
+ LayerInputConfig* input = cfg.layerConfig.add_inputs();
+ // TODO(TJ): uncomment me when refine and support comparing all zeroes vector
+ // cfg.layerConfig.set_active_type("relu");
+ cfg.layerConfig.add_inputs();
+ cfg.layerConfig.add_inputs();
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(pm.ic);
+ img_conf->set_img_size_y(pm.ih);
+ img_conf->set_img_size(pm.iw);
+}
+
+void testBatchNormLayer(const testBatchNormDesc& pm) {
+ TestConfig dnnConfig;
+ getMKLDNNBatchNormConfig(dnnConfig, pm);
+ TestConfig refConfig = dnnConfig;
+ refConfig.layerConfig.set_type("batch_norm");
+ // for PASS_TRAIN, use_global_stats always should be false, and batchsize != 1
+ VLOG(MKLDNN_TESTS) << "check train phase";
+ dnnConfig.layerConfig.set_use_global_stats(false);
+ refConfig.layerConfig.set_use_global_stats(false);
+ MKLDNNTester tester;
+ tester.run(dnnConfig, refConfig, pm.bs, pm.ih, pm.iw, PASS_TRAIN);
+ // for PASS_TEST, check use_global_stats true and false, and batchsize 1
+ VLOG(MKLDNN_TESTS) << "check test phase";
+ for (auto useGS : {false, true}) {
+ dnnConfig.layerConfig.set_use_global_stats(useGS);
+ refConfig.layerConfig.set_use_global_stats(useGS);
+ MKLDNNTester tester;
+ for (auto bs : {pm.bs, 1}) {
+ tester.run(dnnConfig, refConfig, bs, pm.ih, pm.iw, PASS_TEST);
+ }
+ }
+}
+
+TEST(MKLDNNLayer, BatchNormLayer) {
+ testBatchNormLayer({4, 10, 6, 6});
+ testBatchNormLayer({16, 32, 16, 16});
+}
+
struct testActDesc {
int bs, ic, ih, iw;
};
diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h
index fe755d096da9713e39581a909e5d21aa93d69f0f..2b62d4e11ac7276924947ab47360ffca84240aea 100644
--- a/paddle/math/MKLDNNMatrix.h
+++ b/paddle/math/MKLDNNMatrix.h
@@ -91,6 +91,11 @@ public:
const MKLDNNMatrixPtr& dst,
bool checkData = true);
+ void copyFrom(const Matrix& src) {
+ // TODO(TJ): reorder data if this format is not nchw or x
+ m_->copyFrom(src);
+ }
+
public:
/**
* Reorder this MKLDNNMatrix from other format.
diff --git a/paddle/math/RowBuffer.h b/paddle/math/RowBuffer.h
index 9ef5b89680b00981188d78cb312dc75e2c0a79ee..e457d71f1b357aecae48107688499edd7271a5db 100644
--- a/paddle/math/RowBuffer.h
+++ b/paddle/math/RowBuffer.h
@@ -60,7 +60,7 @@ public:
*/
inline real* get(int row) const {
if (preallocatedBuf_) {
- CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize());
+ CHECK_LE((row)*width_ * sizeof(real), preallocatedBuf_->getSize());
return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_;
} else {
CHECK_LE((row + 1) * width_, rowStore_.size());
diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h
index 9b36182c2b619317da31310141823442d8fd3f94..29c20e18601b71bac5201df8ff0c7ce0bed702dc 100644
--- a/paddle/memory/memcpy.h
+++ b/paddle/memory/memcpy.h
@@ -54,6 +54,5 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif
-
} // namespace memory
} // namespace paddle
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 39250480db37f95abec506ba3c9653e5fd6db788..eaa9884443386cebdf686e25143d99fec17646f2 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -89,7 +89,7 @@ function(op_library TARGET)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
-
+
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
@@ -131,6 +131,7 @@ set(DEPS_OPS
pool_op
pool_with_index_op
conv_op
+ sequence_conv_op
lstm_op)
@@ -139,10 +140,11 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
-op_library(sum_op DEPS net_op)
op_library(conv_op DEPS vol2col)
+op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
+op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
@@ -157,3 +159,4 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
+cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index ee4f9b0ef29cc73907bc09fb6014850cb4e58a67..90f1535fcd387c34ea39d84d9c2ec78fcbc3c764 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -446,12 +446,16 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp,
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
hard_sigmoid_grad, ops::ActivationOpGrad);
-#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
- REGISTER_OP_CPU_KERNEL( \
- act_type, \
- ops::ActivationKernel>); \
- REGISTER_OP_CPU_KERNEL(act_type##_grad, \
- ops::ActivationGradKernel>);
+#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
+ REGISTER_OP_CPU_KERNEL( \
+ act_type, \
+ ops::ActivationKernel>, \
+ ops::ActivationKernel>); \
+ REGISTER_OP_CPU_KERNEL( \
+ act_type##_grad, ops::ActivationGradKernel>, \
+ ops::ActivationGradKernel>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu
index 7b7644519d4e9cadcc4ca62ccb599262feffa660..97737857ab25dfa92163b64a750fd7a7d9ea0ac3 100644
--- a/paddle/operators/activation_op.cu
+++ b/paddle/operators/activation_op.cu
@@ -17,12 +17,16 @@
namespace ops = paddle::operators;
-#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
- REGISTER_OP_GPU_KERNEL( \
- act_type, \
- ops::ActivationKernel>); \
- REGISTER_OP_GPU_KERNEL(act_type##_grad, \
- ops::ActivationGradKernel>);
+#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
+ REGISTER_OP_GPU_KERNEL( \
+ act_type, \
+ ops::ActivationKernel>, \
+ ops::ActivationKernel>); \
+ REGISTER_OP_GPU_KERNEL( \
+ act_type##_grad, ops::ActivationGradKernel>, \
+ ops::ActivationGradKernel>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index 4f4eb44fedc0a89cdcf60fb7177014a11eb96048..e4c6b2e09cd71f00a2ef73173205b9066c34fcf5 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -210,8 +210,8 @@ struct HardShrinkFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y) const {
- auto temp1 = (x < (threshold * -1)).template cast().eval();
- auto temp2 = (x > threshold).template cast().eval();
+ auto temp1 = (x < static_cast(threshold * -1)).template cast().eval();
+ auto temp2 = (x > static_cast(threshold)).template cast().eval();
y.device(d) = x * (temp1 + temp2);
}
};
@@ -226,8 +226,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- auto temp1 = (x < (threshold * -1)).template cast().eval();
- auto temp2 = (x > threshold).template cast().eval();
+ auto temp1 = (x < static_cast(threshold * -1)).template cast().eval();
+ auto temp2 = (x > static_cast(threshold)).template cast().eval();
dx.device(d) = dy * (temp1 + temp2).template cast();
}
};
@@ -243,9 +243,10 @@ struct SoftShrinkFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- auto temp1 = (x > lambda).template cast().eval();
- auto temp2 = (x < -lambda).template cast().eval();
- y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda);
+ auto lambdaT = static_cast(lambda);
+ auto temp1 = (x > lambdaT).template cast().eval();
+ auto temp2 = (x < -lambdaT).template cast().eval();
+ y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
}
};
@@ -257,8 +258,9 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- auto temp1 = (x > lambda).template cast().eval();
- auto temp2 = (x < -lambda).template cast().eval();
+ auto lambdaT = static_cast(lambda);
+ auto temp1 = (x > lambdaT).template cast().eval();
+ auto temp2 = (x < -lambdaT).template cast().eval();
dx.device(d) = dy * (temp1 + temp2).template cast();
}
};
@@ -362,7 +364,8 @@ struct BReluFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max);
+ y.device(d) =
+ x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max));
}
};
@@ -375,7 +378,9 @@ struct BReluGradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast();
+ dx.device(d) = dy *
+ ((x > static_cast(t_min)) * (x < static_cast(t_max)))
+ .template cast();
}
};
@@ -390,7 +395,8 @@ struct Relu6Functor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- y.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(threshold);
+ y.device(d) =
+ x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold));
}
};
@@ -402,8 +408,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- dx.device(d) =
- dy * ((x > static_cast(0)) * (x < threshold)).template cast();
+ dx.device(d) = dy *
+ ((x > static_cast(0)) * (x < static_cast(threshold)))
+ .template cast();
}
};
@@ -463,7 +470,8 @@ struct SoftReluFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- auto temp = x.cwiseMax(-threshold).cwiseMin(threshold);
+ auto tmp = static_cast(threshold);
+ auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast