提交 bf3ae063 编写于 作者: C chengduoZH

remove conflict

INCLUDE(ExternalProject) include(ExternalProject)
SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl) set(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
include_directories(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO) if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies # If we use DSO, we do not build nccl, just download the dependencies
...@@ -12,9 +11,11 @@ if(WITH_DSO) ...@@ -12,9 +11,11 @@ if(WITH_DSO)
set(NCCL_INSTALL_DIR "") set(NCCL_INSTALL_DIR "")
else() else()
# otherwise, we build nccl and link it. # otherwise, we build nccl and link it.
set(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
# Note: cuda 8.0 is needed to make nccl
# When cuda is not installed on the system directory, need to set CUDA_HOME to your cuda root
set(NCCL_BUILD_COMMAND "make -j 8") set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install") set(NCCL_INSTALL_COMMAND "make install PREFIX=${NCCL_INSTALL_DIR}")
SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
endif() endif()
ExternalProject_Add( ExternalProject_Add(
...@@ -31,20 +32,18 @@ ExternalProject_Add( ...@@ -31,20 +32,18 @@ ExternalProject_Add(
TEST_COMMAND "" TEST_COMMAND ""
) )
if (WITH_DSO) if(WITH_DSO)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0") if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile}) add_library(nccl STATIC ${dummyfile})
else() else()
add_library(nccl INTERFACE) add_library(nccl INTERFACE)
endif() endif()
else() else()
ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL) add_library(nccl STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION set_property(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl.a) ${NCCL_INSTALL_DIR}/lib/libnccl_static.a)
endif() endif()
add_dependencies(nccl extern_nccl) add_dependencies(nccl extern_nccl)
LIST(APPEND external_project_dependencies nccl)
...@@ -2,35 +2,35 @@ ...@@ -2,35 +2,35 @@
## Motivation ## Motivation
The model is the output of training process. One complete model consists of two parts, namely, the **topology** and the **parameters**. To support industrial deployment, we need to make the model format must be self-completed and do not expose any training source code. A model is an output of the training process. One complete model consists of two parts, the **topology** and the **parameters**. In order to support industrial deployment, the model format must be self-complete and must not expose any training source code.
As a result, In PaddlePaddle, the **topology** represents as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model, we must support large size parameter, and efficient serialization/deserialization. As a result, In PaddlePaddle, the **topology** is represented as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model. We must support large size parameters and efficient serialization/deserialization of parameters.
## Implementation ## Implementation
The topology is saved as a plain text, in detail, a self-contain protobuf file. The topology is saved as a plain text in a detailed self-contain protobuf file.
The parameters are saved as a binary file. As we all know, the protobuf message has the limits of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We do a (benchmark experiment)[https://github.com/PaddlePaddle/Paddle/pull/4610], its result shows protobuf is not fit in this scene. The parameters are saved as a binary file. As we all know, the protobuf message has a limit of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We have done a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), which shows that protobuf is not fit for the task.
As a result, we design a particular format for tensor serialization. By default, arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of (LoDTensorDesc)[https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99]. We save the DescProto as the byte string header, it contains the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). Tensor stores value in a continuous memory buffer, for speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is, As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**| The table below shows a tensor's byte view in detail. Note that all the signed values are written in the little-endian format.
|field name | type | description |
| --- | --- | --- |
| version | uint32_t | Version of saved file. Always 0 now. |
| tensor desc length | uint32_t | TensorDesc(Protobuf message) length in bytes. |
| tensor desc | void* | TensorDesc protobuf binary message |
| tensor data | void* | Tensor's data in binary format. The length of `tensor_data` is decided by `TensorDesc.dims()` and `TensorDesc.data_type()` |
| lod_level | uint64_t | Level of LoD |
| length of lod[0] | uint64_t | [Optional] length of lod[0] in bytes. |
| data of lod[0] | uint64_t* | [Optional] lod[0].data() |
| ... | ... | ... |
In detail, tensor's byte view as the table shows. Note that all the signed value written in little-endian.
```text
[offset] [type] [description]
0004 4 bytes integer HeaderLength, the length of LoDTensorDesc
0008 4 bytes integer ContentLength, the length of LodTensor Buffer
0009 1 bytes char TensorDesc
00010 1 bytes char TensorDesc
...
00100 1 bytes char TensorValue
00101 1 bytes char TensorValue
00102 1 bytes char TensorValue ..
...
```
## Summary ## Summary
We introduce the model format, the `ProgramDesc` describe the **topology**, and a bunch of particular format binary tensors describes the **parameters**. - We introduce a model format.
- The model represented by its forward-pass computation procedure is saved in a **ProgramDesc** protobuf message.
- A bunch of specified format binary tensors describe the **parameters**.
# Regularization in PaddlePaddle # Regularization in PaddlePaddle
## Introduction to Regularization ## Introduction to Regularization
A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**. A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. A frequently faced problem is the problem of **overfitting**, where the model does not make reliable predictions on new unseen data. **Regularization** is the process of introducing additional information in order to prevent overfitting. This is usually done by adding extra penalties to the loss function that restricts the parameter spaces that an optimization algorithm can explore.
### Parameter Norm Penalties ### Parameter Norm Penalties
Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows: Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows:
...@@ -18,52 +18,21 @@ The most commonly used norm penalties are the L2 norm penalty and the L1 norm pe ...@@ -18,52 +18,21 @@ The most commonly used norm penalties are the L2 norm penalty and the L1 norm pe
##### L1 Regularization ##### L1 Regularization
<img src="./images/l1_regularization.png" align="center"/><br/> <img src="./images/l1_regularization.png" align="center"/><br/>
A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html). A much more detailed mathematical background of regularization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
## Regularization Survey
## How to do Regularization in PaddlePaddle A detailed survey of regularization in various deep learning frameworks can be found [here](https://github.com/PaddlePaddle/Paddle/wiki/Regularization-Survey).
On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization:
1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows:
```python
opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2)
```
At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet:
```python
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
```
This is a very restyrictive way of doing regularization and does not give the users enough flexibility.
**Advantages**:
- It is easy to implement for us.
- Faster execution of backward. However, it can be done manually by advanced users too.
**Disadvantages**:
- Not flexible for other regularizations such as L1/L0 regularization.
- Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized.
- Tightly coupled optimizer and regularization implementation.
2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer.
**Advantages**:
- Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization.
- Makes it easy for the users to customize and extend the framework.
**Disadvantages**:
- Implementation requires comprehensive design and time.
## Proposal for Regularization in PaddlePaddle ## Proposal for Regularization in PaddlePaddle
### Low-Level implementation ### Low-Level implementation
In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations: In the new design, we propose to create new operations for regularization. For now, we can add 2 ops that correspond to the most frequently used regularizations:
- L2_regularization_op - L2_regularization_op
- L1_regularization_op - L1_regularization_op
These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties. These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate CPU and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes other than L1 and L2 norm penalties.
The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API. The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API.
...@@ -94,7 +63,7 @@ Since we want to create the regularization ops in a lazy manner, the regularizat ...@@ -94,7 +63,7 @@ Since we want to create the regularization ops in a lazy manner, the regularizat
#### High-level API #### High-level API
In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers). In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we also need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
......
...@@ -25,9 +25,8 @@ import ( ...@@ -25,9 +25,8 @@ import (
"strings" "strings"
"time" "time"
log "github.com/inconshreveable/log15"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
...@@ -41,16 +40,20 @@ func main() { ...@@ -41,16 +40,20 @@ func main() {
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warn, error, crit")
flag.Parse() flag.Parse()
level, e := log.ParseLevel(*logLevel) lvl, err := log.LvlFromString(*logLevel)
candy.Must(e) if err != nil {
panic(err)
}
log.SetLevel(level) log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
if *endpoints == "" { if *endpoints == "" {
log.Warningln("-endpoints not set, fault tolerance not be enabled.") log.Warn("-endpoints not set, fault tolerance not be enabled.")
} }
var store master.Store var store master.Store
...@@ -58,23 +61,25 @@ func main() { ...@@ -58,23 +61,25 @@ func main() {
eps := strings.Split(*endpoints, ",") eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP() ip, err := networkhelper.GetExternalIP()
if err != nil { if err != nil {
log.Fatal(err) log.Crit("get external ip error", log.Ctx{"error": err})
panic(err)
} }
addr := fmt.Sprintf("%s:%d", ip, *port) addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error creating etcd client.", log.Ctx{"error": err})
panic(err)
} }
} else { } else {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() { shutdown := func() {
log.Infoln("shutting down gracefully") log.Info("shutting down gracefully")
err := store.Shutdown() err := store.Shutdown()
if err != nil { if err != nil {
log.Errorln(err) log.Error("shutdown error", log.Ctx{"error": err})
} }
} }
...@@ -86,24 +91,28 @@ func main() { ...@@ -86,24 +91,28 @@ func main() {
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error creating new service.", log.Ctx{"error": err})
panic(err)
} }
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error registering to etcd.", log.Ctx{"error": err})
panic(err)
} }
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
panic(err)
} }
go func() { go func() {
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error serving HTTP", log.Ctx{"error": err})
panic(err)
} }
}() }()
......
...@@ -27,11 +27,11 @@ import ( ...@@ -27,11 +27,11 @@ import (
"github.com/topicai/candy" "github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 8001, "port of the pserver")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry") index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
...@@ -41,13 +41,17 @@ func main() { ...@@ -41,13 +41,17 @@ func main() {
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warn, error, crit")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) lvl, err := log.LvlFromString(*logLevel)
candy.Must(err) if err != nil {
panic(err)
}
log.SetLevel(level) log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
var idx int var idx int
...@@ -63,7 +67,7 @@ func main() { ...@@ -63,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx) cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil { if err != nil {
if err == pserver.ErrCheckpointNotFound { if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.") log.Info("load checkpoint error", "error", err)
} else { } else {
panic(err) panic(err)
} }
...@@ -71,10 +75,10 @@ func main() { ...@@ -71,10 +75,10 @@ func main() {
} }
shutdown := func() { shutdown := func() {
log.Infoln("shutting down gracefully") log.Info("shutting down gracefully")
sErr := e.Shutdown() sErr := e.Shutdown()
if sErr != nil { if sErr != nil {
log.Errorln(sErr) log.Error("error shutting down", log.Ctx{"error": sErr})
} }
} }
...@@ -95,7 +99,7 @@ func main() { ...@@ -95,7 +99,7 @@ func main() {
candy.Must(err) candy.Must(err)
go func() { go func() {
log.Infof("start pserver at port %d", *port) log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err) candy.Must(err)
}() }()
......
hash: 328e7b9b7306b45e7b9879139a9f86698115981f6283032e1312093a6a6ddb04 hash: 51d9e2e46d7fd9173ff11ecada40f7b7728756be18d5e2f032535f66465e6e15
updated: 2017-10-16T08:00:23.484693528Z updated: 2017-10-24T15:04:09.987751592-07:00
imports: imports:
- name: github.com/alecthomas/gometalinter - name: github.com/alecthomas/gometalinter
version: bae2f1293d092fd8167939d5108d1b025eaef9de version: bae2f1293d092fd8167939d5108d1b025eaef9de
...@@ -99,6 +99,8 @@ imports: ...@@ -99,6 +99,8 @@ imports:
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml - name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7 version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/go-stack/stack
version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
- name: github.com/gogo/protobuf - name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages: subpackages:
...@@ -120,8 +122,14 @@ imports: ...@@ -120,8 +122,14 @@ imports:
- runtime - runtime
- runtime/internal - runtime/internal
- utilities - utilities
- name: github.com/inconshreveable/log15
version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
- name: github.com/jonboulle/clockwork - name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09 version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/mattn/go-colorable
version: 5411d3eea5978e6cdc258b30de592b60df6aba96
- name: github.com/mattn/go-isatty
version: 57fdcb988a5c543893cc61bce354a6e24ab70022
- name: github.com/matttproud/golang_protobuf_extensions - name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages: subpackages:
...@@ -179,11 +187,12 @@ imports: ...@@ -179,11 +187,12 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: 0f826bdd13b500be0f1d4004938ad978fcc6031e version: e48874b42435b4347fc52bdee0424a52abc974d7
repo: https://github.com/golang/sys.git repo: https://github.com/golang/sys.git
vcs: git vcs: git
subpackages: subpackages:
- unix - unix
- windows
- name: golang.org/x/text - name: golang.org/x/text
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git repo: https://github.com/golang/text.git
...@@ -222,4 +231,3 @@ testImports: ...@@ -222,4 +231,3 @@ testImports:
version: 05e8a0eda380579888eb53c394909df027f06991 version: 05e8a0eda380579888eb53c394909df027f06991
subpackages: subpackages:
- assert - assert
...@@ -26,3 +26,7 @@ import: ...@@ -26,3 +26,7 @@ import:
version: v1.1.0 version: v1.1.0
- package: github.com/alecthomas/gometalinter - package: github.com/alecthomas/gometalinter
version: v1.2.1 version: v1.2.1
- package: github.com/inconshreveable/log15
version: v2.13
- package: github.com/go-stack/stack
version: v1.6.0
...@@ -35,13 +35,19 @@ import ( ...@@ -35,13 +35,19 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client) var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client var curHandle C.paddle_master_client
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
func add(c *master.Client) C.paddle_master_client { func add(c *master.Client) C.paddle_master_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
...@@ -117,7 +123,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -117,7 +123,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
} }
err := c.SetDataset(paths) err := c.SetDataset(paths)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error set dataset", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }
...@@ -167,7 +173,7 @@ func paddle_request_save_model(client C.paddle_master_client, trainerID string, ...@@ -167,7 +173,7 @@ func paddle_request_save_model(client C.paddle_master_client, trainerID string,
c := get(client) c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond) need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error request save model", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }
......
...@@ -21,7 +21,7 @@ import ( ...@@ -21,7 +21,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// Client is the client of the master server. // Client is the client of the master server.
...@@ -75,7 +75,7 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { ...@@ -75,7 +75,7 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
for { for {
err := f() err := f()
if err != nil { if err != nil {
log.Warningln(err) log.Warn("create etcd client error", log.Ctx{"error": err})
} else { } else {
break break
} }
...@@ -135,13 +135,13 @@ func (c *Client) getRecords(passID int) { ...@@ -135,13 +135,13 @@ func (c *Client) getRecords(passID int) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
log.Errorf("getTask error: %s", err) log.Error("getTask error.", log.Ctx{"error": err})
} }
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, e := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if e != nil { if e != nil {
log.Errorln(e) log.Error("error open chunk", log.Ctx{"error": e})
continue continue
} }
...@@ -152,12 +152,15 @@ func (c *Client) getRecords(passID int) { ...@@ -152,12 +152,15 @@ func (c *Client) getRecords(passID int) {
if s.Err() != nil { if s.Err() != nil {
c.ch <- record{nil, s.Err()} c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path) log.Error(
"error scan chunk",
log.Ctx{"error": err, "path": chunk.Path},
)
} }
err = f.Close() err = f.Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error close record file", log.Ctx{"error": err})
} }
} }
...@@ -166,7 +169,7 @@ func (c *Client) getRecords(passID int) { ...@@ -166,7 +169,7 @@ func (c *Client) getRecords(passID int) {
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
err = c.taskFinished(t.Meta.ID) err = c.taskFinished(t.Meta.ID)
if err != nil { if err != nil {
log.Errorln(err) log.Error("task finish callback error.", log.Ctx{"error": err})
} }
} }
} }
...@@ -179,12 +182,12 @@ func (c *Client) monitorMaster(addrCh <-chan string) { ...@@ -179,12 +182,12 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("close old master addr error", log.Ctx{"error": err})
} }
} else { } else {
err := c.conn.Connect(curMaster) err := c.conn.Connect(curMaster)
if err != nil { if err != nil {
log.Errorln(err) log.Error("connect to new master addr error", log.Ctx{"error": err})
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
......
...@@ -25,8 +25,6 @@ import ( ...@@ -25,8 +25,6 @@ import (
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
...@@ -36,10 +34,6 @@ const ( ...@@ -36,10 +34,6 @@ const (
chunkPerTask = 10 chunkPerTask = 10
) )
func init() {
log.SetLevel(log.ErrorLevel)
}
func TestGetFinishTask(t *testing.T) { func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0" const path = "/tmp/master_client_test_0"
......
...@@ -20,7 +20,7 @@ import ( ...@@ -20,7 +20,7 @@ import (
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -44,7 +44,7 @@ type EtcdClient struct { ...@@ -44,7 +44,7 @@ type EtcdClient struct {
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
...@@ -64,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -64,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Infof("Trying to acquire lock at %s.", lockPath) log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("Successfully acquired lock at %s.", lockPath) log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
put := clientv3.OpPut(addrPath, addr) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
...@@ -78,7 +78,8 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat ...@@ -78,7 +78,8 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Fatal("No longer owns the master lock. Exiting.") log.Crit("No longer owns the master lock. Exiting.")
panic("No longer owns the master lock. Exiting.")
} }
e := &EtcdClient{ e := &EtcdClient{
...@@ -102,7 +103,7 @@ func (e *EtcdClient) Save(state []byte) error { ...@@ -102,7 +103,7 @@ func (e *EtcdClient) Save(state []byte) error {
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock again") log.Error("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx) err := e.lock.Lock(ctx)
cancel() cancel()
...@@ -116,9 +117,10 @@ func (e *EtcdClient) Save(state []byte) error { ...@@ -116,9 +117,10 @@ func (e *EtcdClient) Save(state []byte) error {
// to kill current master server. The current // to kill current master server. The current
// state is not saved, but the trainer's RPC // state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry. // call will fail, so the trainer will retry.
log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
panic("Could not acquire the lock at %s: %v. Exiting.")
} }
log.Infof("Successfully acquired lock at %s.", e.lockPath) log.Info("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state) return e.Save(state)
} }
...@@ -136,7 +138,7 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -136,7 +138,7 @@ func (e *EtcdClient) Load() ([]byte, error) {
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock and load again.") log.Error("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background()) err = e.lock.Lock(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -163,7 +165,7 @@ func (e *EtcdClient) Shutdown() error { ...@@ -163,7 +165,7 @@ func (e *EtcdClient) Shutdown() error {
if err == nil { if err == nil {
err = newErr err = newErr
} else { } else {
log.Errorln(newErr) log.Error("shutdown error", log.Ctx{"error": newErr})
} }
} }
...@@ -192,7 +194,7 @@ func watchKey(c *clientv3.Client, key string, valChan chan<- string) { ...@@ -192,7 +194,7 @@ func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
for wresp := range rch { for wresp := range rch {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string // if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value) log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
valChan <- string(ev.Kv.Value) valChan <- string(ev.Kv.Value)
} }
} }
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
...@@ -170,11 +170,11 @@ func (s *Service) recover() (bool, error) { ...@@ -170,11 +170,11 @@ func (s *Service) recover() (bool, error) {
} }
if state == nil { if state == nil {
log.Infoln("No state exists, not recovered.") log.Info("No state exists, not recovered.")
return false, nil return false, nil
} }
log.Infof("Loaded snapshot of size: %d bytes.", len(state)) log.Info("Loaded snapshot.", log.Ctx{"size": len(state)})
gr, err := gzip.NewReader(bytes.NewReader(state)) gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil { if err != nil {
return false, err return false, err
...@@ -191,11 +191,11 @@ func (s *Service) recover() (bool, error) { ...@@ -191,11 +191,11 @@ func (s *Service) recover() (bool, error) {
if err != nil { if err != nil {
// Only close failed, recover actually succeed, so // Only close failed, recover actually succeed, so
// just log error. // just log error.
log.Errorln(err) log.Error("error close recover file.", log.Ctx{"error": err})
} }
s.state = tqs s.state = tqs
log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.") log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx())
for _, t := range s.state.Pending { for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
} }
...@@ -224,7 +224,7 @@ func (s *Service) snapshot() error { ...@@ -224,7 +224,7 @@ func (s *Service) snapshot() error {
} }
state := buf.Bytes() state := buf.Bytes()
log.Infof("Saving snapshot of size: %d bytes.", len(state)) log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)})
return s.store.Save(state) return s.store.Save(state)
} }
...@@ -260,7 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -260,7 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
count := index.NumChunks() count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count) log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count})
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := Chunk{ chunk := Chunk{
Path: path, Path: path,
...@@ -300,7 +300,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error { ...@@ -300,7 +300,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
return err return err
} }
close(s.ready) close(s.ready)
...@@ -320,7 +320,7 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -320,7 +320,7 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
defer func() { defer func() {
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
} }
}() }()
...@@ -328,12 +328,12 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { ...@@ -328,12 +328,12 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
t.NumFailure++ t.NumFailure++
if t.NumFailure > s.failureMax { if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Failed = append(s.state.Failed, t) s.state.Failed = append(s.state.Failed, t)
return return
} }
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure) log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Todo = append(s.state.Todo, t) s.state.Todo = append(s.state.Todo, t)
return return
} }
...@@ -353,8 +353,8 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -353,8 +353,8 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
} }
// must be called with lock held. // must be called with lock held.
func (s *Service) logFields() log.Fields { func (s *Service) logCtx() log.Ctx {
return log.Fields{ return log.Ctx{
"todoLen": len(s.state.Todo), "todoLen": len(s.state.Todo),
"pendingLen": len(s.state.Pending), "pendingLen": len(s.state.Pending),
"doneLen": len(s.state.Done), "doneLen": len(s.state.Done),
...@@ -383,10 +383,10 @@ func (s *Service) GetTask(passID int, task *Task) error { ...@@ -383,10 +383,10 @@ func (s *Service) GetTask(passID int, task *Task) error {
if len(s.state.Todo) == 0 { if len(s.state.Todo) == 0 {
if len(s.state.Done) == 0 && len(s.state.Pending) == 0 { if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass") log.Warn("All tasks failed, may start next pass", s.logCtx())
return ErrAllTaskFailed return ErrAllTaskFailed
} }
log.WithFields(s.logFields()).Warningln("No more available task.") log.Warn("No more available task.", s.logCtx())
return ErrNoMoreAvailable return ErrNoMoreAvailable
} }
...@@ -400,8 +400,9 @@ func (s *Service) GetTask(passID int, task *Task) error { ...@@ -400,8 +400,9 @@ func (s *Service) GetTask(passID int, task *Task) error {
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta) ctx := s.logCtx()
ctx["task meta"] = t.Task.Meta
log.Info("Task dispatched.", ctx)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
...@@ -417,7 +418,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -417,7 +418,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.state.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) ctx := s.logCtx()
ctx["task id"] = taskID
log.Warn("Pending task not found.", ctx)
return nil return nil
} }
...@@ -426,7 +429,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -426,7 +429,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = append(s.state.Done, t) s.state.Done = append(s.state.Done, t)
delete(s.state.Pending, taskID) delete(s.state.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) ctx := s.logCtx()
ctx["task id"] = taskID
log.Info("Task finished.", ctx)
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 { if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished // increase master side pass count if all tasks finished
s.state.CurPass++ s.state.CurPass++
...@@ -434,12 +439,14 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -434,12 +439,14 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = []taskEntry{} s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks // TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{} s.state.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass) ctx := s.logCtx()
ctx["new pass"] = s.state.CurPass
log.Warn("all task finished, add new pass data.", ctx)
} }
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
} }
return err return err
} }
...@@ -455,7 +462,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { ...@@ -455,7 +462,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
t, ok := s.state.Pending[meta.ID] t, ok := s.state.Pending[meta.ID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta) log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta})
return nil return nil
} }
......
...@@ -45,9 +45,15 @@ import ( ...@@ -45,9 +45,15 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client" "github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
...@@ -164,10 +170,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, ...@@ -164,10 +170,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name) log.Warn(
"parameter already initialized, treat paddle_init_param as successful.",
log.Ctx{"parameter": name},
)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Error("error init param", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -180,11 +189,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { ...@@ -180,11 +189,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.") log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Error("error finish init params", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -205,7 +214,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient ...@@ -205,7 +214,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient
c := get(client) c := get(client)
err := c.SendGrads(gs) err := c.SendGrads(gs)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error send grads", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -222,7 +231,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -222,7 +231,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
c := get(client) c := get(client)
ps, err := c.GetParams(ns) ps, err := c.GetParams(ns)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error get params", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -231,7 +240,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -231,7 +240,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Error(
"pserver returned wrong number of parameters.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -241,7 +256,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -241,7 +256,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Error(
"pserver returned wrong parameters, or not in requested order.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
...@@ -251,13 +272,19 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -251,13 +272,19 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nil { if unsafe.Pointer(param) == nil {
log.Errorln("must pre-allocate parameter.") log.Error("must pre-allocate parameter.")
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
if unsafe.Pointer(param.content) != nil { if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) { if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) log.Error(
"the pre-allocated content len does not match parameter content len.",
log.Ctx{
"Pre-allocated len": param.content_len,
"Returned len": len(p.Content),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
...@@ -84,7 +84,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -84,7 +84,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
if curServers[i].Addr == "" { if curServers[i].Addr == "" {
err := c.pservers[i].Close() err := c.pservers[i].Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error closing connection to pserver", log.Ctx{"error": err})
} }
continue continue
...@@ -92,7 +92,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -92,7 +92,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
err := c.pservers[i].Connect(curServers[i].Addr) err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error connecting to pserver", log.Ctx{"error": err})
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
......
...@@ -30,7 +30,7 @@ import ( ...@@ -30,7 +30,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client" "github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -90,7 +90,7 @@ func initEtcdClient() { ...@@ -90,7 +90,7 @@ func initEtcdClient() {
DialTimeout: time.Second * time.Duration(1), DialTimeout: time.Second * time.Duration(1),
}) })
if err != nil { if err != nil {
log.Errorf("err %v", err) log.Error("error init etcd client", log.Ctx{"error": err})
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err = client.Delete(ctx, pserver.PsDesired) _, err = client.Delete(ctx, pserver.PsDesired)
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -54,26 +54,29 @@ func (e *Etcd) Desired() int { ...@@ -54,26 +54,29 @@ func (e *Etcd) Desired() int {
resp, err := e.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Error(
"Get ps dresire number failed! reconnecting...",
log.Ctx{"error": err},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Info("Waiting for ps desired registered ...")
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err) log.Error("atoi failed", log.Ctx{"error": err})
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("Get psDesired number: %d", psDesired) log.Debug("Got psDesired", log.Ctx{"psDesired": psDesired})
break break
} }
return psDesired return psDesired
...@@ -88,17 +91,20 @@ func (e *Etcd) List() []Server { ...@@ -88,17 +91,20 @@ func (e *Etcd) List() []Server {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debug("looking for pserver", log.Ctx{"ps key": psKey})
resp, err := e.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel() cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Info(
"Get psKey error",
log.Ctx{"ps key": psKey, "error": err},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Info("Waiting for ps addr registered ...")
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -106,11 +112,17 @@ func (e *Etcd) List() []Server { ...@@ -106,11 +112,17 @@ func (e *Etcd) List() []Server {
psAddr := string(resp.Kvs[0].Value) psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Info(
"Value under psKey is empty",
log.Ctx{"psKey": psKey},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("got value (%s) for key: %s", psAddr, psKey) log.Debug(
"got psAddr given psKey",
log.Ctx{"psAddr": psAddr, "psKey": psKey},
)
servers[i].Index = i servers[i].Index = i
servers[i].Addr = psAddr servers[i].Addr = psAddr
} }
...@@ -130,13 +142,13 @@ func NewEtcd(endpoints string) *Etcd { ...@@ -130,13 +142,13 @@ func NewEtcd(endpoints string) *Etcd {
DialTimeout: defaultEtcdTimeout, DialTimeout: defaultEtcdTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("Init etcd connection failed: %v", err) log.Error("Init etcd connection failed", log.Ctx{"error": err})
time.Sleep(defaultEtcdTimeout) time.Sleep(defaultEtcdTimeout)
continue continue
} }
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Info("Connected to etcd endpoint", log.Ctx{"endpoint": endpoints})
client := &Etcd{ client := &Etcd{
client: cli, client: cli,
timeout: defaultEtcdTimeout, timeout: defaultEtcdTimeout,
...@@ -154,7 +166,7 @@ func (e *Etcd) Select() (bool, error) { ...@@ -154,7 +166,7 @@ func (e *Etcd) Select() (bool, error) {
} }
lock := concurrency.NewMutex(sess, initLockPath) lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath) log.Info("Trying to acquire lock", log.Ctx{"lock path": initLockPath})
// Do not use timeout context here, since we don't know how // Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the // long does it take for other trainers to initialize the
// parameters. // parameters.
...@@ -162,7 +174,7 @@ func (e *Etcd) Select() (bool, error) { ...@@ -162,7 +174,7 @@ func (e *Etcd) Select() (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
log.Infof("Successfully acquired lock at %s.", initLockPath) log.Info("Successfully acquired lock", log.Ctx{"lock path": initLockPath})
get := clientv3.OpGet(initDonePath) get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
...@@ -181,17 +193,17 @@ func (e *Etcd) Select() (bool, error) { ...@@ -181,17 +193,17 @@ func (e *Etcd) Select() (bool, error) {
if len(resp.Kvs) == 0 { if len(resp.Kvs) == 0 {
// Key value not set, select current trainer. // Key value not set, select current trainer.
e.lock = lock e.lock = lock
log.Infoln("Trainer selected.") log.Info("Trainer selected.")
return true, nil return true, nil
} }
if string(resp.Kvs[0].Value) == initDoneVal { if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.") log.Info("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout) ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx) err = lock.Unlock(ctx)
cancel() cancel()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error unlocking", log.Ctx{"error": err})
} }
return false, nil return false, nil
} }
...@@ -221,7 +233,7 @@ func (e *Etcd) Done() error { ...@@ -221,7 +233,7 @@ func (e *Etcd) Done() error {
err = e.lock.Unlock(ctx) err = e.lock.Unlock(ctx)
cancel() cancel()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error unlocking", log.Ctx{"error": err})
} else { } else {
e.lock = nil e.lock = nil
} }
...@@ -244,7 +256,7 @@ func (e *Etcd) Close() error { ...@@ -244,7 +256,7 @@ func (e *Etcd) Close() error {
cErr := e.client.Close() cErr := e.client.Close()
if cErr != nil { if cErr != nil {
if err != nil { if err != nil {
log.Errorln(cErr) log.Error("error closing etcd client", log.Ctx{"error": cErr})
return err return err
} }
return cErr return cErr
......
...@@ -24,7 +24,7 @@ import ( ...@@ -24,7 +24,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
...@@ -82,19 +82,19 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -82,19 +82,19 @@ func (e *EtcdClient) Register(port int) (int, error) {
DialTimeout: e.dialTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Error("connect to etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.client = cli e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec)) sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil { if err != nil {
log.Errorf("create etcd session error: %v", err) log.Error("create etcd session error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.sess = sess e.sess = sess
log.Debugf("inited client to %s", e.endpoints) log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints})
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
...@@ -104,7 +104,7 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -104,7 +104,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
_, err := e.initDesiredPservers(ctx, e.numPservers) _, err := e.initDesiredPservers(ctx, e.numPservers)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
...@@ -119,14 +119,17 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -119,14 +119,17 @@ func (e *EtcdClient) Register(port int) (int, error) {
resp, err := e.client.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Error(
"psDesired atoi error",
log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)},
)
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
...@@ -143,7 +146,7 @@ func (e *EtcdClient) Register(port int) (int, error) { ...@@ -143,7 +146,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
pserverIdx, err = e.registerPserverEtcd(ctx, port) pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn("register pserver on etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
...@@ -170,16 +173,17 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er ...@@ -170,16 +173,17 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey) ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debug(
"register pserver got value",
log.Ctx{"value": ps, "key": psKey},
)
if ps == "" { if ps == "" {
// find the first id and write info // find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease())) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr})
log.Debug("register finished")
idx = i idx = i
registered = true registered = true
break break
...@@ -239,7 +243,7 @@ func (e *EtcdClient) Shutdown() error { ...@@ -239,7 +243,7 @@ func (e *EtcdClient) Shutdown() error {
newErr := e.client.Close() newErr := e.client.Close()
if newErr != nil { if newErr != nil {
if err != nil { if err != nil {
log.Errorln(newErr) log.Error("shutdown error", log.Ctx{"error": newErr})
} else { } else {
err = newErr err = newErr
} }
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"fmt" "fmt"
"unsafe" "unsafe"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
type optimizer struct { type optimizer struct {
...@@ -56,12 +56,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer ...@@ -56,12 +56,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State s := State
paramBufferSize := C.size_t(len(p.Content)) paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{ log.Info("New Optimizer Created with config", log.Ctx{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": paramBufferSize, "ParamSize": paramBufferSize,
"ConfigSize": len(c), "ConfigSize": len(c),
"StateSize": len(s), "StateSize": len(s),
}).Info("New Optimizer Created with config:") })
var cbuffer unsafe.Pointer var cbuffer unsafe.Pointer
cbuffer = C.malloc(paramBufferSize) cbuffer = C.malloc(paramBufferSize)
...@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer ...@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0]) cstate = unsafe.Pointer(&s[0])
} }
var cptr (*C.uchar)
if len(c) > 0 {
cptr = (*C.uchar)(&c[0])
} else {
log.Error("empty config", "param name", paramWithConfigs.Param.Name)
}
o.config = c o.config = c
o.opt = C.paddle_create_optimizer( o.opt = C.paddle_create_optimizer(
(*C.uchar)(&c[0]), cptr,
C.int(len(c)), C.int(len(c)),
C.paddle_element_type(p.ElementType), C.paddle_element_type(p.ElementType),
cbuffer, cbuffer,
......
...@@ -17,12 +17,11 @@ package pserver ...@@ -17,12 +17,11 @@ package pserver
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/md5"
"encoding/gob" "encoding/gob"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"hash/crc32"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
...@@ -32,7 +31,7 @@ import ( ...@@ -32,7 +31,7 @@ import (
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -40,7 +39,7 @@ type ElementType int ...@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could // ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found. // not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found") var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
// RPC error message. // RPC error message.
const ( const (
...@@ -76,7 +75,7 @@ type ParameterWithConfig struct { ...@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct { type checkpointMeta struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Path string `json:"path"` Path string `json:"path"`
MD5 string `json:"md5"` CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
...@@ -92,7 +91,7 @@ type Service struct { ...@@ -92,7 +91,7 @@ type Service struct {
idx int idx int
checkpointInterval time.Duration checkpointInterval time.Duration
checkpointPath string checkpointPath string
client *EtcdClient client KVStore
mu sync.Mutex mu sync.Mutex
optMap map[string]*optimizer optMap map[string]*optimizer
...@@ -104,7 +103,12 @@ type parameterCheckpoint struct { ...@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte State []byte
} }
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { type KVStore interface {
GetKey(key string, timeout time.Duration) ([]byte, error)
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}
func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second) v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil { if err != nil {
return return
...@@ -123,7 +127,10 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { ...@@ -123,7 +127,10 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
} }
// LoadCheckpoint loads checkpoint from file. // LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")
cpMeta, err := loadMeta(e, idx) cpMeta, err := loadMeta(e, idx)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -134,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { ...@@ -134,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err return nil, err
} }
// TODO(helin): change MD5 to CRC since CRC is better for file crc32 := crc32.ChecksumIEEE(content)
// checksum in our use case (emphasize speed over security). if crc32 != cpMeta.CRC32 {
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
return nil, errors.New(WrongChecksum) return nil, errors.New(WrongChecksum)
} }
...@@ -147,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { ...@@ -147,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil { if err = dec.Decode(&cp); err != nil {
return nil, err return nil, err
} }
return cp, nil return cp, nil
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: interval, checkpointInterval: interval,
...@@ -170,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient ...@@ -170,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
} }
s.optMap[p.Param.Name] = newOptimizer(p, item.State) s.optMap[p.Param.Name] = newOptimizer(p, item.State)
} }
close(s.initialized)
} }
return s, nil return s, nil
} }
...@@ -178,6 +184,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient ...@@ -178,6 +184,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
log.Warn("init param called but parameters already initialized.")
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -191,6 +198,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error ...@@ -191,6 +198,13 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
log.Info(
"init parameter",
"name", paramWithConfigs.Param.Name,
"config len", len(paramWithConfigs.Config),
"param len", len(paramWithConfigs.Param.Content),
"type", paramWithConfigs.Param.ElementType,
)
return nil return nil
} }
...@@ -199,6 +213,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error ...@@ -199,6 +213,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error
func (s *Service) FinishInitParams(_ int, _ *int) error { func (s *Service) FinishInitParams(_ int, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
log.Warn("finished init param called but parameters already initialized.")
return errors.New(AlreadyInitialized) return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -209,10 +224,12 @@ func (s *Service) FinishInitParams(_ int, _ *int) error { ...@@ -209,10 +224,12 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t { for range t {
err := s.checkpoint() err := s.checkpoint()
if err != nil { if err != nil {
log.Errorln(err) log.Error("checkpoint error", log.Ctx{"error": err})
} }
} }
}() }()
log.Info("init parameter finished.")
return nil return nil
} }
...@@ -222,6 +239,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error { ...@@ -222,6 +239,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
log.Warn("received gradient before initialization.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
return errors.New(Uninitialized) return errors.New(Uninitialized)
} }
...@@ -233,6 +251,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error { ...@@ -233,6 +251,7 @@ func (s *Service) SendGrad(g Gradient, _ *int) error {
return fmt.Errorf("parameter: %s does not exist", g.Name) return fmt.Errorf("parameter: %s does not exist", g.Name)
} }
log.Info("received gradient from trainer, updating gradient.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
return o.UpdateParameter(g) return o.UpdateParameter(g)
} }
...@@ -244,6 +263,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -244,6 +263,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
opt, ok := s.optMap[name] opt, ok := s.optMap[name]
if !ok { if !ok {
log.Warn("trainer wants to get a parameter that does not exist.", "name", name)
return fmt.Errorf("parameter: %s does not exist", name) return fmt.Errorf("parameter: %s does not exist", name)
} }
...@@ -257,12 +277,14 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -257,12 +277,14 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name parameter.Name = name
parameter.ElementType = opt.elementType parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights() parameter.Content = opt.GetWeights()
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil return nil
} }
func traceTime(start time.Time, name string) { func traceTime(start time.Time, name string) {
elapsed := time.Since(start) elapsed := time.Since(start)
log.Infof("%s took %v", name, elapsed) log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
} }
// checkpoint saves checkpoint to disk. // checkpoint saves checkpoint to disk.
...@@ -270,7 +292,7 @@ func traceTime(start time.Time, name string) { ...@@ -270,7 +292,7 @@ func traceTime(start time.Time, name string) {
// checkpoint should be only called after the parameters are // checkpoint should be only called after the parameters are
// initialized. // initialized.
func (s *Service) checkpoint() (err error) { func (s *Service) checkpoint() (err error) {
log.Infoln("Begin save checkpoint.") log.Info("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint") defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock() s.mu.Lock()
...@@ -315,7 +337,7 @@ func (s *Service) checkpoint() (err error) { ...@@ -315,7 +337,7 @@ func (s *Service) checkpoint() (err error) {
closeErr := f.Close() closeErr := f.Close()
if closeErr != nil { if closeErr != nil {
if err != nil { if err != nil {
log.Errorln(closeErr) log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
} else { } else {
// Set closeErr as return value. // Set closeErr as return value.
err = closeErr err = closeErr
...@@ -336,20 +358,29 @@ func (s *Service) checkpoint() (err error) { ...@@ -336,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx) oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound { if err == ErrCheckpointNotFound {
log.Infoln("Do not have existing checkpoint.") log.Info("old meta not found, skip removing old meta")
err = nil err = nil
} else if err == nil {
log.Info("removing old meta")
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
} }
if err != nil { if err != nil {
return return
} }
h := md5.New() crc32 := crc32.ChecksumIEEE(buf.Bytes())
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
cpMeta := checkpointMeta{ cpMeta := checkpointMeta{
UUID: id, UUID: id,
Timestamp: time.Now().UnixNano(), Timestamp: time.Now().UnixNano(),
MD5: md5, CRC32: crc32,
Path: p, Path: p,
} }
...@@ -363,14 +394,5 @@ func (s *Service) checkpoint() (err error) { ...@@ -363,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return return
} }
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Errorln(rmErr)
}
}
return return
} }
package pserver
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const testDir = "./test_data"
type myKV struct {
m map[string][]byte
}
func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
if m.m == nil {
m.m = make(map[string][]byte)
}
return m.m[key], nil
}
func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
if m.m == nil {
m.m = make(map[string][]byte)
}
m.m[key] = value
return nil
}
func TestCheckpoint(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
err = s.checkpoint()
assert.Nil(t, err)
_, err = LoadCheckpoint(kv, 0)
assert.Nil(t, err)
}
func float32ToByte(f float32) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, f)
if err != nil {
fmt.Println("binary.Write failed:", err)
}
return buf.Bytes()
}
func TestCheckpointWithData(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
var content []byte
for i := 0; i < 50000; i++ {
content = append(content, float32ToByte(float32(i))...)
}
p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
assert.Nil(t, err)
err = s.FinishInitParams(0, nil)
assert.Nil(t, err)
var p2 Parameter
err = s.GetParam(p1.Name, &p2)
assert.Nil(t, err)
assert.Equal(t, p1, p2)
err = s.checkpoint()
assert.Nil(t, err)
cp, err := LoadCheckpoint(kv, 0)
assert.Nil(t, err)
s1, err := NewService(0, time.Hour, testDir, kv, cp)
assert.Nil(t, err)
var p3 Parameter
err = s1.GetParam(p1.Name, &p3)
assert.Nil(t, err)
assert.Equal(t, p1, p3)
}
...@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}
...@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters( ...@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize); modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize); is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config; paddle::TrainerConfig config;
paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) { if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
!modelConfig.IsInitialized()) {
return kPD_PROTOBUF_ERROR; return kPD_PROTOBUF_ERROR;
} }
} else {
modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine(); auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create( ptr->machine.reset(paddle::GradientMachine::create(
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE})); modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters(); std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
for (auto& para : parameters) { for (auto& para : parameters) {
para->load(is); para->load(is);
......
# ddim lib # ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(saver_proto SRCS framework.proto saver.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -10,7 +9,7 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) ...@@ -10,7 +9,7 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
...@@ -27,7 +26,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) ...@@ -27,7 +26,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
...@@ -43,7 +42,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -43,7 +42,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
......
...@@ -315,6 +315,7 @@ static void CreateGradVarInBlock( ...@@ -315,6 +315,7 @@ static void CreateGradVarInBlock(
return false; /* not break */ return false; /* not break */
}); });
if (need_infer_shape) { if (need_infer_shape) {
ops[op_index]->InferVarType(block_desc);
ops[op_index]->InferShape(*block_desc); ops[op_index]->InferShape(*block_desc);
} }
} }
...@@ -452,11 +453,16 @@ ParamGradInfoMap AppendBackward( ...@@ -452,11 +453,16 @@ ParamGradInfoMap AppendBackward(
std::transform(target_shape_desc.begin(), target_shape_desc.end(), std::transform(target_shape_desc.begin(), target_shape_desc.end(),
std::back_inserter(target_shape), std::back_inserter(target_shape),
[](int64_t dim) { return static_cast<int>(dim); }); [](int64_t dim) { return static_cast<int>(dim); });
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
std::unique_ptr<OpDescBind> fill_one_op( std::unique_ptr<OpDescBind> fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}}, new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", target_shape}, {{"shape", target_shape},
{"value", static_cast<float>(1.0)}, {"value", static_cast<float>(1.0)},
{"data_type", framework::DataType::FP32}})); {"data_type", target.GetDataType()}}));
// infer var type of fill_one_op
fill_one_op->InferVarType(root_block);
root_block->AppendAllocatedOp(std::move(fill_one_op)); root_block->AppendAllocatedOp(std::move(fill_one_op));
size_t forward_op_num = root_block->OpSize(); size_t forward_op_num = root_block->OpSize();
size_t forward_block_num = program_desc.Size(); size_t forward_block_num = program_desc.Size();
...@@ -475,8 +481,7 @@ ParamGradInfoMap AppendBackward( ...@@ -475,8 +481,7 @@ ParamGradInfoMap AppendBackward(
std::unordered_map<std::string, GradVarInfo> retv; std::unordered_map<std::string, GradVarInfo> retv;
auto var = root_block->Var(fill_one_op_out); auto var = root_block->Var(fill_one_op_out);
// FIXME(qiao) infer the data type var->SetDataType(target.GetDataType());
var->SetDataType(framework::DataType::FP32);
var->SetShape(target.Shape()); var->SetShape(target.Shape());
auto& target_grad = retv[target.Name()]; auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out; target_grad.name_ = fill_one_op_out;
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "paddle/framework/var_desc.h" #include "paddle/framework/var_desc.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
USE_OP(fill_constant);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() { ...@@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() {
Flush(); Flush();
return desc_; return desc_;
} }
BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {
for (const VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc));
}
for (const OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog));
}
}
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc, BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog) ProgramDescBind *prog)
: prog_(prog), desc_(desc) { : prog_(prog), desc_(desc) {
......
...@@ -36,8 +36,7 @@ class ProgramDescBind; ...@@ -36,8 +36,7 @@ class ProgramDescBind;
class BlockDescBind { class BlockDescBind {
public: public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
: prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &other, BlockDesc *desc, BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog); ProgramDescBind *prog);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <typeindex> #include <typeindex>
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -28,7 +28,8 @@ enum OpInfoFillType { ...@@ -28,7 +28,8 @@ enum OpInfoFillType {
kOperator = 0, kOperator = 0,
kOpProtoAndCheckerMaker = 1, kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2, kGradOpDescMaker = 2,
kVarTypeInference = 3 kVarTypeInference = 3,
kShapeInference = 4
}; };
template <typename T> template <typename T>
...@@ -42,7 +43,10 @@ struct OpInfoFillTypeID { ...@@ -42,7 +43,10 @@ struct OpInfoFillTypeID {
? kGradOpDescMaker ? kGradOpDescMaker
: (std::is_base_of<VarTypeInference, T>::value : (std::is_base_of<VarTypeInference, T>::value
? kVarTypeInference ? kVarTypeInference
: static_cast<OpInfoFillType>(-1)))); : (std::is_base_of<InferShapeBase, T>::value
? kShapeInference
: static_cast<OpInfoFillType>(
-1)))));
} }
}; };
...@@ -121,6 +125,16 @@ struct OpInfoFiller<T, kVarTypeInference> { ...@@ -121,6 +125,16 @@ struct OpInfoFiller<T, kVarTypeInference> {
} }
}; };
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_shape_ = [](InferShapeContext* ctx) {
T inference;
inference(ctx);
};
}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <set> #include <set>
#include <vector> #include <vector>
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
...@@ -56,6 +57,22 @@ Executor::~Executor() { ...@@ -56,6 +57,22 @@ Executor::~Executor() {
} }
} }
static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
if (var_type == VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == VarDesc::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == VarDesc::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else {
PADDLE_THROW(
"Variable type must be "
"LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication) // - only runs on the first device (i.e. no interdevice communication)
...@@ -69,10 +86,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { ...@@ -69,10 +86,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
for (auto& var : block.vars()) { for (auto& var : block.vars()) {
if (var.persistable()) { if (var.persistable()) {
auto* ptr = scope->Var(var.name()); auto* ptr = scope->Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name() VLOG(3) << "Create Variable " << var.name()
<< " global, which pointer is " << ptr; << " global, which pointer is " << ptr;
} else { } else {
auto* ptr = local_scope.Var(var.name()); auto* ptr = local_scope.Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name() VLOG(3) << "Create Variable " << var.name()
<< " locally, which pointer is " << ptr; << " locally, which pointer is " << ptr;
} }
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/saver.pb.h"
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
...@@ -106,6 +105,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const { ...@@ -106,6 +105,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const {
return lod_[level][idx + 1] - lod_[level][idx]; return lod_[level][idx + 1] - lod_[level][idx];
} }
size_t LoDTensor::NumInstancesInElement(size_t level, size_t idx) const {
PADDLE_ENFORCE_LT(level, NumLevels());
PADDLE_ENFORCE_LT(idx, NumElements(level));
auto abs_lod = ToAbsOffset(lod());
size_t begin = abs_lod[level][idx];
size_t end = abs_lod[level][idx + 1];
return end - begin;
}
void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) { void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end); auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod; lod_ = new_lod;
...@@ -117,144 +125,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, ...@@ -117,144 +125,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
PADDLE_ENFORCE_LT(elem_begin, NumElements(level)); PADDLE_ENFORCE_LT(elem_begin, NumElements(level));
PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1); PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1);
auto abs_lod = framework::ToAbsOffset(lod());
auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end); auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end);
lod_ = new_lod; lod_ = new_lod;
}
std::string LoDTensor::SerializeToString() const {
LoDTensorProto desc;
// set data_type
if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL);
if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16);
if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32);
if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64);
// FIXME(dzh): there is no fp16 in standard c++
if (this->type() == typeid(float)) // NOLINT
desc.set_data_type(DataType::FP32);
if (this->type() == typeid(double)) // NOLINT
desc.set_data_type(DataType::FP64);
for (int i = 0; i < dims().size(); ++i) {
desc.add_dims(dims()[i]);
}
// set lod information
desc.set_lod_level(this->NumLevels());
for (size_t i = 0; i < this->NumLevels(); ++i) {
LoDInfo* lod = desc.add_levels();
for (size_t j = 0; j < lod_[i].size(); ++j) {
lod->add_level(lod_[i][j]);
}
}
desc.set_version(0);
std::string desc_bytes = desc.SerializeAsString();
// FIXME(dzh) : implement fix chunk size buffer.
size_t DESC_SIZE = desc_bytes.size();
size_t DATA_SIZE = holder_->size() - offset_;
const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t);
char* buffer =
static_cast<char*>(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE));
// format: desc_size data_size, desc_bytes, data_bytes.
platform::CPUPlace src_place;
platform::CPUPlace dst_place;
memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t));
memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE,
sizeof(size_t));
memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place,
desc_bytes.c_str(), desc_bytes.size());
PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!");
platform::Place place = holder_->place(); // slice the underlying tensor
int element_width = holder_->size() / this->numel(); size_t begin = abs_lod[level][elem_begin];
size_t end = abs_lod[level][elem_end];
if (platform::is_cpu_place(place)) { PADDLE_ENFORCE_LT(begin, end, "Cannot shrink, the result tensor is empty.");
memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), ShareDataWith(Slice(begin, end));
boost::get<platform::CPUPlace>(place),
static_cast<char*>(holder_->ptr()) + offset_ / element_width,
DATA_SIZE);
}
#ifdef PADDLE_WITH_GPU
if (platform::is_gpu_place(place)) {
memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
boost::get<platform::GPUPlace>(place),
static_cast<char*>(holder_->ptr()) + offset_ / element_width,
DATA_SIZE);
}
#endif
std::string ret(buffer, BUFFER_SIZE);
memory::Free(platform::CPUPlace(), buffer);
return ret;
} }
void LoDTensor::DeserializeFromString(const std::string& s,
const platform::Place& dst_place) {
size_t DESC_SIZE, BUFFER_SIZE;
platform::CPUPlace src_place;
memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t));
memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t),
sizeof(size_t));
const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2;
// parse LoDTensorDesc
LoDTensorProto desc;
desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE);
std::vector<int64_t> dims;
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
this->Resize(make_ddim(dims));
// parse data type
void* ptr = nullptr;
if (desc.data_type() == DataType::BOOL)
ptr = this->mutable_data<bool>(dst_place);
if (desc.data_type() == DataType::INT16)
ptr = this->mutable_data<int16_t>(dst_place);
if (desc.data_type() == DataType::INT32)
ptr = this->mutable_data<int32_t>(dst_place);
if (desc.data_type() == DataType::INT64)
ptr = this->mutable_data<int64_t>(dst_place);
// FIXME(dzh): there is no fp16 in standard c++
if (desc.data_type() == DataType::FP32)
ptr = this->mutable_data<float>(dst_place);
if (desc.data_type() == DataType::FP64)
ptr = this->mutable_data<double>(dst_place);
LoD lod;
std::vector<size_t> levels;
for (int i = 0; i < desc.levels().size(); ++i) {
auto current_level = desc.levels()[i].level();
std::copy(current_level.begin(), current_level.end(),
std::back_inserter(levels));
lod.emplace_back(levels);
levels.clear();
}
this->set_lod(lod);
if (platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), ptr, src_place,
s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
}
#ifdef PADDLE_WITH_GPU
if (platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), ptr, src_place,
s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
}
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -85,7 +85,9 @@ class LoDTensor : public Tensor { ...@@ -85,7 +85,9 @@ class LoDTensor : public Tensor {
void set_lod(const LoD& lod) { lod_ = lod; } void set_lod(const LoD& lod) { lod_ = lod; }
LoD lod() const { return lod_; } const LoD& lod() const { return lod_; }
LoD* mutable_lod() { return &lod_; }
/* /*
* Get the start offset and end offset of an element from LoD. * Get the start offset and end offset of an element from LoD.
...@@ -122,6 +124,12 @@ class LoDTensor : public Tensor { ...@@ -122,6 +124,12 @@ class LoDTensor : public Tensor {
*/ */
size_t NumElements(size_t level, size_t idx) const; size_t NumElements(size_t level, size_t idx) const;
/*
* Get the number of instances in the underlying tensor in the `idx`-th
* element.
*/
size_t NumInstancesInElement(size_t level, size_t idx) const;
/* /*
* Shrink levels[level_begin:level_end] * Shrink levels[level_begin:level_end]
*/ */
...@@ -133,29 +141,45 @@ class LoDTensor : public Tensor { ...@@ -133,29 +141,45 @@ class LoDTensor : public Tensor {
*/ */
void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end); void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end);
/**
* @brief Serialize tensor to char bytes.
* Please check model_format.md for the format detail.
* NOTE: GPUTensor will copy data to cpu implicitly.
* @return return string
*/
// FIXME(dzh) : Currently, this interface should only be used in
// save/restore model and checkpoint. ParameterServer do not use shape
// information to do the optimization, as a result, when we serialize
// parameter/gradient to string, we should serialize the tensor
// to string in the ps trainer instead of LoDTensor.
std::string SerializeToString() const;
/**
* @brief Deserialize char bytes to tensor.
* @return return string
*/
void DeserializeFromString(const std::string& s,
const platform::Place& dst_place);
private: private:
LoD lod_; LoD lod_;
}; };
/*
* Expand the `source` to fit the LoD of `lod`. For example, a `source`
* LoDTensor is
* - LoD: [0, 2]
* - tensor: [a0, a1]
* a `lod` is
* - LoD: [0 3 5]
* returns a new LoDTensor
* - [a0 a0 a0 a1 a1]
*/
template <typename T>
LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
const platform::Place& place) {
LoD abs_lod = ToAbsOffset(lod);
const auto& lod_level = lod[level];
size_t num_instances = source.dims()[0];
// new tensor
LoDTensor tensor;
tensor.set_lod(lod);
auto dims = source.dims();
dims[0] = lod_level.back();
tensor.Resize(dims);
tensor.mutable_data<T>(place);
PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
tensor.Slice(elem, elem + 1)
.CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
return tensor;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { ...@@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
size_t level = 0; size_t level = 0;
LoDTensor new_lod_tensor = lod_tensor_; LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.ShrinkInLevel(level, 0, 1); new_lod_tensor.ShrinkInLevel(level, 0, 1);
EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
EXPECT_EQ(new_lod_tensor.NumElements(0), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 5UL); ASSERT_EQ(new_lod_tensor.NumElements(2), 5UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>()); ASSERT_EQ(new_lod_tensor.dims()[0], 12);
for (int i = 0; i < 12 * 128; i++) {
ASSERT_EQ(new_lod_tensor.data<float>()[i], i);
}
level = 1; level = 1;
new_lod_tensor = lod_tensor_; new_lod_tensor = lod_tensor_;
...@@ -104,23 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { ...@@ -104,23 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>()); ASSERT_EQ(new_lod_tensor.dims()[0], 7);
for (int i = 5 * 128; i < 12 * 128; i++) {
ASSERT_EQ(new_lod_tensor.data<float>()[i - 5 * 128], i);
}
LoDTensor t1;
t1.set_lod(lod_tensor_.lod());
t1.ShareDataWith(lod_tensor_);
LoDTensor t2;
t2.set_lod(lod_tensor_.lod());
t2.ShareDataWith(lod_tensor_);
t1.ShrinkInLevel(0, 1, 2);
t2.ShrinkInLevel(0, 0, 1);
EXPECT_NE(t1.data<float>(), t2.data<float>());
EXPECT_NE(t1.data<float>(), lod_tensor_.data<float>());
} }
TEST_F(LoDTensorTester, SerializeDeserialize) { TEST(LodExpand, test) {
LoDTensor new_lod_tensor = lod_tensor_; LoD lod{{0, 2}};
float* src_ptr = lod_tensor_.data<float>(); LoDTensor tensor;
std::string s = lod_tensor_.SerializeToString(); tensor.set_lod(lod);
LoDTensor dst; tensor.Resize({2, 1});
dst.DeserializeFromString(s, platform::CPUPlace()); tensor.mutable_data<float>(platform::CPUPlace());
float* dst_ptr = dst.data<float>(); tensor.data<float>()[0] = 0;
for (int i = 0; i < kLodTensorSize; ++i) { tensor.data<float>()[1] = 1;
EXPECT_EQ(dst_ptr[i], src_ptr[i]);
LoD target;
target.emplace_back(std::vector<size_t>{0, 3, 5});
auto new_tensor = LodExpand<float>(tensor, target, 0UL, platform::CPUPlace());
std::vector<int> result{{0, 0, 0, 1, 1}};
for (size_t i = 0; i < 5; i++) {
ASSERT_EQ(new_tensor.data<float>()[i], result[i]);
} }
ASSERT_EQ(dst.NumElements(0), 2UL);
ASSERT_EQ(dst.NumElements(1), 3UL);
ASSERT_EQ(dst.NumElements(2), 8UL);
} }
} // namespace framework } // namespace framework
......
...@@ -48,30 +48,3 @@ TEST(LoDTensor, LoDInGPU) { ...@@ -48,30 +48,3 @@ TEST(LoDTensor, LoDInGPU) {
CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
} }
} }
TEST(LoDTensor, SerializeDeserialize) {
paddle::framework::LoDTensor lod_tensor;
paddle::platform::GPUPlace place(0);
paddle::framework::LoD src_lod;
src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});
lod_tensor.Resize({14, 16});
lod_tensor.mutable_data<float>(place);
lod_tensor.set_lod(src_lod);
CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
test<<<1, 8>>>(src_lod[0].data(), src_lod[0].size());
cudaDeviceSynchronize();
std::string s = lod_tensor.SerializeToString();
paddle::framework::LoDTensor dst;
dst.DeserializeFromString(s, place);
paddle::framework::LoD dst_lod = dst.lod();
for (size_t i = 0; i < dst_lod[0].size(); ++i) {
CHECK_EQ(src_lod[0].data()[i], dst_lod[0].data()[i] * 2);
}
}
...@@ -14,9 +14,13 @@ limitations under the License. */ ...@@ -14,9 +14,13 @@ limitations under the License. */
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include <functional> #include <functional>
#include <mutex>
#include <unordered_map> #include <unordered_map>
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -24,16 +28,47 @@ namespace framework { ...@@ -24,16 +28,47 @@ namespace framework {
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs) { const AttributeMap &attrs) {
op_desc_.set_type(type); desc_.set_type(type);
inputs_ = inputs; inputs_ = inputs;
outputs_ = outputs; outputs_ = outputs;
attrs_ = attrs; attrs_ = attrs;
need_update_ = true; need_update_ = true;
} }
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
int input_size = desc_.inputs_size();
for (int i = 0; i < input_size; ++i) {
const OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
for (int j = 0; j < argu_size; ++j) {
args.push_back(var.arguments(j));
}
}
// restore outputs_
int output_size = desc_.outputs_size();
for (int i = 0; i < output_size; ++i) {
const OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
for (int j = 0; j < argu_size; ++j) {
args.push_back(var.arguments(j));
}
}
// restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
}
}
OpDesc *OpDescBind::Proto() { OpDesc *OpDescBind::Proto() {
Flush(); Flush();
return &op_desc_; return &desc_;
} }
const std::vector<std::string> &OpDescBind::Input( const std::vector<std::string> &OpDescBind::Input(
...@@ -167,23 +202,23 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -167,23 +202,23 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void OpDescBind::Flush() { void OpDescBind::Flush() {
if (need_update_) { if (need_update_) {
this->op_desc_.mutable_inputs()->Clear(); this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) { for (auto &ipt : inputs_) {
auto *input = op_desc_.add_inputs(); auto *input = desc_.add_inputs();
input->set_parameter(ipt.first); input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments()); VectorToRepeated(ipt.second, input->mutable_arguments());
} }
this->op_desc_.mutable_outputs()->Clear(); this->desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) { for (auto &opt : outputs_) {
auto *output = op_desc_.add_outputs(); auto *output = desc_.add_outputs();
output->set_parameter(opt.first); output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments()); VectorToRepeated(opt.second, output->mutable_arguments());
} }
this->op_desc_.mutable_attrs()->Clear(); this->desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) { for (auto &attr : attrs_) {
auto *attr_desc = op_desc_.add_attrs(); auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first); attr_desc->set_name(attr.first);
attr_desc->set_type( attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1)); static_cast<framework::AttrType>(attr.second.which() - 1));
...@@ -195,26 +230,26 @@ void OpDescBind::Flush() { ...@@ -195,26 +230,26 @@ void OpDescBind::Flush() {
} }
} }
using InferShapeFuncMap = static std::once_flag init_infer_shape_funcs;
std::unordered_map<std::string /*op_type*/,
std::function<void(InferShapeContext *)>>; static void InitInferShapeFuncs() {
std::call_once(init_infer_shape_funcs, [] {
static InferShapeFuncMap &InferShapeFuncs() { auto &map = OpInfoMap::Instance();
static InferShapeFuncMap *g_map = nullptr; auto &info_map = *map.mutable_map();
if (g_map == nullptr) {
g_map = new InferShapeFuncMap(); for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
auto &info_map = OpInfoMap::Instance(); auto op_type = kern_pair.first;
// all registered kernels auto &op_info = info_map.at(op_type);
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
auto &info = info_map.Get(pair.first);
// use empty type here to avoid runtime checks.
auto op = auto op =
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {})); static_cast<OperatorWithKernel *>(op_info.Creator()("", {}, {}, {}));
g_map->insert( if (op_info.infer_shape_) { // infer_shape has been registered.
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }}); continue;
} }
op_info.infer_shape_ = [op](InferShapeContext *ctx) {
op->InferShape(ctx);
};
} }
return *g_map; });
} }
void OpDescBind::CheckAttrs() { void OpDescBind::CheckAttrs() {
...@@ -230,13 +265,13 @@ void OpDescBind::CheckAttrs() { ...@@ -230,13 +265,13 @@ void OpDescBind::CheckAttrs() {
} }
void OpDescBind::InferShape(const BlockDescBind &block) const { void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs(); VLOG(3) << "CompileTime infer shape on " << Type();
auto it = funcs.find(this->Type()); InitInferShapeFuncs();
if (it == funcs.end()) { auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
PADDLE_THROW("Operator %s has not been registered", this->Type()); PADDLE_ENFORCE(static_cast<bool>(infer_shape),
} "%s's infer_shape has not been registered", this->Type());
CompileTimeInferShapeContext ctx(*this, block); CompileTimeInferShapeContext ctx(*this, block);
it->second(&ctx); infer_shape(&ctx);
} }
void OpDescBind::InferVarType(BlockDescBind *block) const { void OpDescBind::InferVarType(BlockDescBind *block) const {
......
...@@ -24,6 +24,7 @@ namespace paddle { ...@@ -24,6 +24,7 @@ namespace paddle {
namespace framework { namespace framework {
class BlockDescBind; class BlockDescBind;
class ProgramDescBind;
class OpDescBind { class OpDescBind {
public: public:
...@@ -32,11 +33,13 @@ class OpDescBind { ...@@ -32,11 +33,13 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs, OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs); const VariableNameMap &outputs, const AttributeMap &attrs);
OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
OpDesc *Proto(); OpDesc *Proto();
std::string Type() const { return op_desc_.type(); } std::string Type() const { return desc_.type(); }
void SetType(const std::string &type) { op_desc_.set_type(type); } void SetType(const std::string &type) { desc_.set_type(type); }
const std::vector<std::string> &Input(const std::string &name) const; const std::vector<std::string> &Input(const std::string &name) const;
...@@ -117,7 +120,7 @@ class OpDescBind { ...@@ -117,7 +120,7 @@ class OpDescBind {
return ret_val; return ret_val;
} }
OpDesc op_desc_; OpDesc desc_;
VariableNameMap inputs_; VariableNameMap inputs_;
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
......
...@@ -25,12 +25,19 @@ ...@@ -25,12 +25,19 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InferShapeBase {
public:
virtual ~InferShapeBase() = default;
virtual void operator()(InferShapeContext*) const = 0;
};
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
GradOpMakerFN grad_op_maker_; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_; InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_;
bool HasOpProtoAndChecker() const { bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;
...@@ -87,13 +94,13 @@ class OpInfoMap { ...@@ -87,13 +94,13 @@ class OpInfoMap {
} }
} }
const std::unordered_map<std::string, const OpInfo>& map() const { const std::unordered_map<std::string, OpInfo>& map() const { return map_; }
return map_;
} std::unordered_map<std::string, OpInfo>* mutable_map() { return &map_; }
private: private:
OpInfoMap() = default; OpInfoMap() = default;
std::unordered_map<std::string, const OpInfo> map_; std::unordered_map<std::string, OpInfo> map_;
DISABLE_COPY_AND_ASSIGN(OpInfoMap); DISABLE_COPY_AND_ASSIGN(OpInfoMap);
}; };
......
...@@ -33,24 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,24 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}
Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return var->GetMutable<Tensor>();
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
...@@ -204,6 +186,30 @@ void OperatorBase::GenerateTemporaryNames() { ...@@ -204,6 +186,30 @@ void OperatorBase::GenerateTemporaryNames() {
} }
} }
static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
t = &(var->Get<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
return t;
}
static Tensor* GetMutableTensorFromVar(Variable* var) {
Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
t = var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
return t;
}
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); auto* var = InputVar(name);
...@@ -227,7 +233,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( ...@@ -227,7 +233,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const { Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto var = OutputVar(name); auto var = OutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<LoDTensor>(); return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
} }
template <> template <>
...@@ -240,7 +246,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( ...@@ -240,7 +246,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr return var == nullptr ? nullptr
: var->GetMutable<LoDTensor>(); : GetMutableTensorFromVar(var);
}); });
return res; return res;
} }
......
...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h" #include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/shape_inference.h" #include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
...@@ -60,9 +61,6 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -60,9 +61,6 @@ inline std::string GradVarName(const std::string& var_name) {
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -414,7 +412,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -414,7 +412,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
private: private:
DDim GetDim(const std::string& name) const override { DDim GetDim(const std::string& name) const override {
return framework::make_ddim(block_.FindVarRecursive(name)->Shape()); auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
return framework::make_ddim(var->Shape());
} }
void SetDim(const std::string& name, const DDim& dim) override { void SetDim(const std::string& name, const DDim& dim) override {
...@@ -511,28 +511,26 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -511,28 +511,26 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
private: private:
template <bool Allocate> DDim GetDim(const std::string& name) const override {
Tensor* GetTensor(const std::string& name) const { Variable* var = scope_.FindVar(name);
Tensor* t = nullptr; if (var->IsType<LoDTensor>()) {
auto* var = scope_.FindVar(name); return var->Get<LoDTensor>().dims();
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) { } else if (var->IsType<SelectedRows>()) {
if (Allocate) { return var->Get<SelectedRows>().GetCompleteDims();
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable(%s) should be tensor", name);
}
} else { } else {
t = GetTensorFromVar(scope_.FindVar(name)); PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
return t;
} }
DDim GetDim(const std::string& name) const override {
return GetTensor<false>(name)->dims();
} }
void SetDim(const std::string& name, const DDim& dim) override { void SetDim(const std::string& name, const DDim& dim) override {
GetTensor<true>(name)->Resize(dim); Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
} }
const OperatorBase& op_; const OperatorBase& op_;
...@@ -638,7 +636,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -638,7 +636,9 @@ class OperatorWithKernel : public OperatorBase {
}); });
} }
virtual void InferShape(InferShapeContext* ctx) const = 0; virtual void InferShape(InferShapeContext* ctx) const {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
}
protected: protected:
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
...@@ -655,11 +655,14 @@ class OperatorWithKernel : public OperatorBase { ...@@ -655,11 +655,14 @@ class OperatorWithKernel : public OperatorBase {
t = &var->Get<Tensor>(); t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) { } else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>(); t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(ToDataType(t->type()));
VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
PADDLE_ENFORCE(tmp == data_type || data_type == -1, PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same."); "DataType of Paddle Op %s must be same.", Type());
data_type = tmp; data_type = tmp;
} }
} }
......
...@@ -237,12 +237,12 @@ TEST(OpKernel, multi_inputs) { ...@@ -237,12 +237,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
scope.Var("x0")->GetMutable<Tensor>(); scope.Var("x0")->GetMutable<LoDTensor>();
scope.Var("x1")->GetMutable<Tensor>(); scope.Var("x1")->GetMutable<LoDTensor>();
scope.Var("x2")->GetMutable<Tensor>(); scope.Var("x2")->GetMutable<LoDTensor>();
scope.Var("k0")->GetMutable<Tensor>(); scope.Var("k0")->GetMutable<LoDTensor>();
scope.Var("y0")->GetMutable<Tensor>(); scope.Var("y0")->GetMutable<LoDTensor>();
scope.Var("y1")->GetMutable<Tensor>(); scope.Var("y1")->GetMutable<LoDTensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
......
...@@ -19,9 +19,9 @@ namespace paddle { ...@@ -19,9 +19,9 @@ namespace paddle {
namespace framework { namespace framework {
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_.add_blocks(); auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID()); b->set_parent_idx(parent.ID());
b->set_idx(prog_.blocks_size() - 1); b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b)); blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get(); return blocks_.back().get();
} }
...@@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() { ...@@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
return &prog_; return &desc_;
} }
ProgramDescBind::ProgramDescBind() { ProgramDescBind::ProgramDescBind() {
auto *block = prog_.mutable_blocks()->Add(); auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex); block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex); block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block)); blocks_.emplace_back(new BlockDescBind(this, block));
} }
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) { ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
prog_ = o.prog_; desc_ = o.desc_;
for (int i = 0; i < prog_.blocks_size(); ++i) { for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = prog_.mutable_blocks(i); auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this)); blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
} }
} }
ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -31,6 +31,8 @@ class ProgramDescBind { ...@@ -31,6 +31,8 @@ class ProgramDescBind {
ProgramDescBind(const ProgramDescBind &o); ProgramDescBind(const ProgramDescBind &o);
explicit ProgramDescBind(const std::string &binary_str);
BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
...@@ -40,7 +42,7 @@ class ProgramDescBind { ...@@ -40,7 +42,7 @@ class ProgramDescBind {
ProgramDesc *Proto(); ProgramDesc *Proto();
private: private:
ProgramDesc prog_; ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_; std::vector<std::unique_ptr<BlockDescBind>> blocks_;
}; };
......
...@@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) {
}; };
ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames()); ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
ASSERT_EQ(3, global_block_copy->LocalVarNames().size()); ASSERT_EQ(3UL, global_block_copy->LocalVarNames().size());
assert_same_var("X", x); assert_same_var("X", x);
assert_same_var("Y", y); assert_same_var("Y", y);
assert_same_var("Out", out); assert_same_var("Out", out);
...@@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) {
// Not check block's protostr are same it because the order of vars could be // Not check block's protostr are same it because the order of vars could be
// different and it is correct. // different and it is correct.
} }
TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin;
auto* global_block = program_origin.Block(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
op->SetType("mul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
std::string binary_str;
program_origin.Proto()->SerializeToString(&binary_str);
ProgramDescBind program_restored(binary_str);
auto* global_block_restored = program_restored.Block(0);
ASSERT_NE(global_block, global_block_restored);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
ASSERT_TRUE(global_block_restored->HasVar(name));
auto* restored = global_block_restored->Var(name);
ASSERT_NE(restored, var_before);
ASSERT_EQ(restored->Name(), var_before->Name());
ASSERT_EQ(restored->GetType(), var_before->GetType());
ASSERT_EQ(restored->Shape(), var_before->Shape());
ASSERT_EQ(restored->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
ASSERT_EQ(global_block->LocalVarNames(),
global_block_restored->LocalVarNames());
ASSERT_EQ(3UL, global_block_restored->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_restored = global_block->Op(i);
ASSERT_EQ(op_origin->Type(), op_restored->Type());
ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs());
ASSERT_EQ(op_origin->Outputs(), op_restored->Outputs());
ASSERT_EQ(op_restored->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -23,7 +23,10 @@ class SelectedRows { ...@@ -23,7 +23,10 @@ class SelectedRows {
value_.reset(new Tensor()); value_.reset(new Tensor());
} }
SelectedRows() { value_.reset(new Tensor()); } SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
}
platform::Place place() const { return value_->place(); } platform::Place place() const { return value_->place(); }
...@@ -37,6 +40,8 @@ class SelectedRows { ...@@ -37,6 +40,8 @@ class SelectedRows {
const Vector<int64_t>& rows() const { return rows_; } const Vector<int64_t>& rows() const { return rows_; }
Vector<int64_t>* mutable_rows() { return &rows_; }
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; } void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
DDim GetCompleteDims() const { DDim GetCompleteDims() const {
......
...@@ -132,6 +132,8 @@ class Tensor { ...@@ -132,6 +132,8 @@ class Tensor {
std::type_index type() const { return holder_->type(); } std::type_index type() const { return holder_->type(); }
size_t memory_size() const;
private: private:
inline void check_memory_size() const; inline void check_memory_size() const;
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include "paddle/framework/eigen.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) { ...@@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) {
values_.resize(index + 1); values_.resize(index + 1);
} }
values_[index].set_lod(value.lod());
values_[index].Resize(value.dims()); values_[index].Resize(value.dims());
values_[index].mutable_data<value_type>(platform::CPUPlace()); values_[index].mutable_data<value_type>(value.place());
values_[index].CopyFrom(value, platform::CPUPlace(), values_[index].CopyFrom(value, value.place(), platform::CPUDeviceContext());
platform::CPUDeviceContext());
} }
void TensorArray::WriteShared(size_t index, const LoDTensor& value) { void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
...@@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) { ...@@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
values_.resize(index + 1); values_.resize(index + 1);
} }
values_[index].set_lod(value.lod());
values_[index].ShareDataWith(value); values_[index].ShareDataWith(value);
} }
...@@ -144,6 +147,155 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level, ...@@ -144,6 +147,155 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
return unpacker.meta; return unpacker.meta;
} }
LoDTensor TensorArray::LodPack(size_t level) const {
PADDLE_ENFORCE_GT(size(), 0UL, "no time step exists");
// the levels should be no less than 2
LoDTensor merged;
const LoDTensor *pre, *cur;
pre = &Read(0);
for (size_t step = 1; step < size(); step++) {
cur = &Read(step);
PADDLE_ENFORCE_GT(cur->NumLevels(), 0);
PADDLE_ENFORCE_GT(pre->NumLevels(), 0);
PADDLE_ENFORCE_EQ(pre->NumLevels(), cur->NumLevels());
PADDLE_ENFORCE_EQ(pre->NumElements(level), cur->NumElements(level));
merged = LodPackTwo(*pre, *cur, level);
pre = &merged;
}
return merged;
}
/*
* NOTE currently, only the lowest level supports packing.
* The lowest LoD will be changed, while the relative offsets in levels above
* stay unchanged.
*
* previous step : [0] [1] [3]
* current step: [0 1 2] [2 3] []
* packed to
* [0 0] [0 1] [0 2] [1 2] [1 3] [3]
*/
LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur,
size_t level) const {
PADDLE_ENFORCE_EQ(pre.NumLevels(), cur.NumLevels());
PADDLE_ENFORCE_EQ(pre.NumLevels(), level + 1,
"Only the lowest LoD level supports pack temporarily.");
// calculate the result tensor's shape first
size_t num_instances = 0;
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
if (num_candidates > 0) {
num_instances += num_candidates * (prefix_size + 1);
} else {
num_instances += prefix_size;
}
}
auto res_dims = pre.dims();
res_dims[0] = num_instances;
LoDTensor result;
result.Resize(res_dims);
result.mutable_data<value_type>(cur.place());
Vector<size_t> last_lod_level;
// copy data
size_t index = 0;
last_lod_level.push_back(index);
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
// slice the prefix Tensor
LoDTensor prefix = pre;
prefix.ShrinkInLevel(level, elem, elem + 1);
LoDTensor candidate = cur;
if (num_candidates > 0) {
candidate.ShrinkInLevel(level, elem, elem + 1);
} else { // just push prefix
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
last_lod_level.push_back(index);
}
for (size_t candi = 0; candi < num_candidates; candi++) {
// TODO(superjom) support GPU
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
// copy candidate record
result.Slice(index, index + 1)
.CopyFrom(candidate.Slice(candi, candi + 1), result.place(),
platform::CPUDeviceContext());
index++;
last_lod_level.push_back(index);
}
}
// update lod
auto lod = cur.lod();
lod.back() = last_lod_level;
result.set_lod(lod);
return result;
}
/*
* source [0 1 2] [3 4] [5 6 7] will be transformd to a list of LoDTensors such
* as
* [0 3 5] [1 4 6] [2 7] with 1-level LoDs:
* - [0 1 2 3]
* - [0 1 2 3]
* - [0 1 1 2], the [1,1) here means the second sequence is empty
*
* NOTE Unpack a LoDTensor in this approach may result in a big LoD.
*/
void TensorArray::LodUnpack(const LoDTensor& source, size_t level) {
PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1,
"only the lowest LoD level supports unpack.");
const size_t non_empty_instances = source.dims()[0];
size_t index = 0;
Vector<size_t> lowest_lod_level;
lowest_lod_level.push_back(index);
for (size_t step = 0; step < non_empty_instances; step++) {
size_t num_instances = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
num_instances++;
index++;
}
lowest_lod_level.push_back(index);
}
// create tensor for this time step
LoDTensor tensor;
auto dims = source.dims();
dims[0] = num_instances;
// set lod
auto lod = source.lod();
lod.back() = lowest_lod_level;
tensor.set_lod(lod);
index = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
// copy this instance
tensor.Slice(index, index + 1)
.CopyFrom(instance.Slice(step, step + 1), tensor.place(),
platform::CPUDeviceContext());
index++;
}
}
Write(step, tensor);
}
}
LoDTensor TensorArray::Stack() const { LoDTensor TensorArray::Stack() const {
LoDTensor result; LoDTensor result;
if (size() == 0) return result; if (size() == 0) return result;
......
...@@ -86,6 +86,16 @@ class TensorArray { ...@@ -86,6 +86,16 @@ class TensorArray {
*/ */
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend); DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
/*
* Pack an array of LoDTensors to a LoDTensor.
*/
LoDTensor LodPack(size_t level) const;
/*
* Unpack a LoDTensor to an array of LoDTensors.
*/
void LodUnpack(const LoDTensor &source, size_t level);
/* /*
* Pack the values into a tensor with rank one higher than each tensor in * Pack the values into a tensor with rank one higher than each tensor in
* values. * values.
...@@ -111,6 +121,9 @@ class TensorArray { ...@@ -111,6 +121,9 @@ class TensorArray {
protected: protected:
void Unstack(const LoDTensor &source, bool data_shared) const; void Unstack(const LoDTensor &source, bool data_shared) const;
LoDTensor LodPackTwo(const LoDTensor &pre, const LoDTensor &cur,
size_t level) const;
private: private:
mutable std::vector<LoDTensor> values_; mutable std::vector<LoDTensor> values_;
}; // class TensorArray }; // class TensorArray
......
...@@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) { ...@@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) {
ASSERT_EQ(ta.size(), static_cast<size_t>(batch_size)); ASSERT_EQ(ta.size(), static_cast<size_t>(batch_size));
} }
TEST(TensorArray, LodPack) {
// three time steps, each step stores a LoDTensors
// - [0] [1]
// - [2 3], [4 5]
// - [6 7] [] [8], [9, 10]
// try to get a LoDTensor with content:
// - [0 2 6]
// - [0 2 7]
// - [0 3]
// - [1 4 8]
// - [1 5 9]
// - [1 5 10]
std::array<LoDTensor, 3> tensors;
tensors[0].Resize(make_ddim({2, 1}));
tensors[1].Resize(make_ddim({4, 1}));
tensors[2].Resize(make_ddim({5, 1}));
int index = 0;
for (auto& t : tensors) {
t.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < t.dims()[0]; i++) {
t.data<int>()[i] = index;
index++;
}
}
std::array<LoD, 3> lods;
std::vector<std::vector<size_t>> levels{
{0, 1, 2}, {0, 2, 4}, {0, 2, 2, 3, 5}};
for (int i = 0; i < 3; i++) {
lods[i].emplace_back(levels[i].begin(), levels[i].end());
}
TensorArray ta;
for (int i = 0; i < 3; i++) {
tensors[i].set_lod(lods[i]);
ta.Write(i, tensors[i]);
}
auto merged = ta.LodPack(0);
std::vector<int> target_tensor_data{{0, 2, 6, // 0
0, 2, 7, // 1
0, 3, // 2
1, 4, 8, // 3
1, 5, 9, // 5
1, 5, 10}};
EXPECT_EQ(merged.dims()[0], (int)target_tensor_data.size());
for (size_t i = 0; i < target_tensor_data.size(); i++) {
EXPECT_EQ(target_tensor_data[i], merged.data<int>()[i]);
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -62,12 +62,16 @@ inline void Tensor::check_memory_size() const { ...@@ -62,12 +62,16 @@ inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
holder_->size(), numel() * SizeOfType(type()) + offset_, holder_->size(), memory_size() + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n" "first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored."); "or maybe the required data-type mismatches the data already stored.");
} }
inline size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : numel() * SizeOfType(type());
}
template <typename T> template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
......
...@@ -28,6 +28,8 @@ class OperatorBase; ...@@ -28,6 +28,8 @@ class OperatorBase;
class OpDescBind; class OpDescBind;
class BlockDescBind; class BlockDescBind;
class BlockDesc; class BlockDesc;
class InferShapeContext;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
...@@ -49,5 +51,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>( ...@@ -49,5 +51,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/, using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/,
BlockDescBind* /*block*/)>; BlockDescBind* /*block*/)>;
using InferShapeFN = std::function<void(InferShapeContext*)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,6 +59,8 @@ class VarDescBind { ...@@ -59,6 +59,8 @@ class VarDescBind {
desc_.set_type(VarDesc::LOD_TENSOR); desc_.set_type(VarDesc::LOD_TENSOR);
} }
explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
VarDesc *Proto() { return &desc_; } VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); } std::string Name() const { return desc_.name(); }
......
...@@ -46,6 +46,8 @@ class Variable { ...@@ -46,6 +46,8 @@ class Variable {
std::type_index(typeid(T)) == std::type_index(holder_->Type()); std::type_index(typeid(T)) == std::type_index(holder_->Type());
} }
void Clear() { holder_.reset(); }
private: private:
struct Placeholder { struct Placeholder {
virtual ~Placeholder() {} virtual ~Placeholder() {}
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MKLDNNBatchNormLayer.h"
using namespace mkldnn; // NOLINT
typedef memory::format format;
namespace paddle {
REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
const real MKLDNNBatchNormLayer::EPS = 1E-5;
bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
return false;
}
// first one is input layer
// the other two are created in config_parser.py saving moving mean and var
CHECK_EQ(inputLayers_.size(), 3U);
CHECK_EQ(inputLayers_.size(), parameters_.size());
CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size()));
const ImageConfig& conf = config_.inputs(0).image_conf();
ic_ = conf.channels();
ih_ = inputLayers_[0]->getOutput().getFrameHeight();
iw_ = inputLayers_[0]->getOutput().getFrameWidth();
if (iw_ == 0 && ih_ == 0) {
iw_ = conf.img_size();
ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
}
oc_ = ic_;
oh_ = ih_;
ow_ = iw_;
if (config_.has_use_global_stats()) {
useGlobalStats_ = config_.use_global_stats();
}
movingAvgFraction_ = config_.moving_average_fraction();
VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
<< " --- global stats";
VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
initWeight();
movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0));
movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0));
return true;
}
void MKLDNNBatchNormLayer::initWeight() {
weight_.reset(new Weight(1, oc_, parameters_[0]));
if (biasParameter_.get() != NULL) {
biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
}
CHECK_EQ(weight_ != nullptr, biases_ != nullptr)
<< "only support have both weight and bias, or neither";
if (weight_ && weight_->getW()) {
CHECK(biases_ && biases_->getW());
valueScaleShift_ = Matrix::create(2, oc_, false, false);
valueScaleShift_->zeroMem();
VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0));
VectorPtr shift(
new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_));
const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE);
const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE);
scale->copyFrom(*wgt);
shift->copyFrom(*bias);
wgt->setData(valueScaleShift_->getData());
bias->setData(valueScaleShift_->getData() + oc_);
}
if (weight_ && weight_->getWGrad()) {
CHECK(biases_ && biases_->getWGrad());
gradScaleShift_ = Matrix::create(2, oc_, false, false);
gradScaleShift_->zeroMem();
const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT);
const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT);
wgt->setData(gradScaleShift_->getData());
bias->setData(gradScaleShift_->getData() + oc_);
}
}
void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
if (hasInitedWgt_) {
return;
}
// prepare mean and var if necessary
if (useGlobalStats_) {
CHECK(mean_);
CHECK(var_);
mean_->copyFrom(*(movingMean_->getW()));
var_->copyFrom(*(movingVar_->getW()));
}
hasInitedWgt_ = true;
}
void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
// calculating and saving moving mean and variance
CHECK_EQ(useGlobalStats_, false);
movingMean_->getW()->add(
*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
// here var is v^2
movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}
void MKLDNNBatchNormLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw);
oh = ih;
ow = ow;
// ic_ and oc can not be changed
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
<< "Input channel can not be changed";
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
printSizeInfo();
}
void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
// In training phase, it will always calculate mean and var,
// so useGlobalStats must be false.
// In scoring phase, it depends on useGlobalStats choice.
if (passType_ != PASS_TEST && useGlobalStats_ == true) {
LOG(WARNING) << "use_global_stats is invalid setting in training phase";
useGlobalStats_ = false;
}
resetFwdBuffers(in, wgt, out);
resetFwdPD(fwdPD_, in, wgt, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, out);
}
void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
std::shared_ptr<bn_bwd::primitive_desc> pd;
resetBwdBuffers(in, wgt, out);
resetBwdPD(pd, in, wgt, out);
resetBwdPipeline(pipeline, pd, in, wgt, out);
}
void MKLDNNBatchNormLayer::forward(PassType passType) {
MKLDNNLayer::forward(passType);
// calculate and save moving mean and variance
if (passType_ != PASS_TEST) {
calMovingMeanAndVar();
}
}
void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) {
weight_->getParameterPtr()->incUpdate(callback);
if (biases_ && biases_->getWGrad()) {
biases_->getParameterPtr()->incUpdate(callback);
}
}
void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
resetInValue(in);
memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_};
CHECK(in);
auto outPD =
MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_);
resetOutValue(out, outPD);
if (valueScaleShift_) {
auto pd = MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_);
resetWithMatrix(wgt, valueScaleShift_, pd);
}
if (passType_ != PASS_TEST || useGlobalStats_) {
auto pd = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
mean_ = MKLDNNMatrix::create(pd);
var_ = MKLDNNMatrix::create(pd);
}
}
void MKLDNNBatchNormLayer::resetFwdPD(
std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in,
MKLDNNMatrixPtr wgt,
MKLDNNMatrixPtr out) {
flags_ = 0u;
prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
: prop_kind::forward_training;
if (useGlobalStats_) {
flags_ = (flags_ | batch_normalization_flag::use_global_stats);
}
if (wgt) {
flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
}
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
// TODO(TJ): use check macro
CHECK(out);
CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc());
if (wgt) {
CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc());
}
if (passType_ != PASS_TEST || useGlobalStats_) {
CHECK(mean_);
CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
CHECK(var_);
CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
}
}
void MKLDNNBatchNormLayer::resetFwdPipeline(
std::vector<primitive>& pipeline,
std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
if (passType_ == PASS_TEST) {
if (useGlobalStats_) {
fwd_.reset(wgt != nullptr ? new bn_fwd(*pd,
*in,
(const primitive::at)(*mean_),
(const primitive::at)(*var_),
*wgt,
*out)
: new bn_fwd(*pd,
*in,
(const primitive::at)(*mean_),
(const primitive::at)(*var_),
*out));
} else {
fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out)
: new bn_fwd(*pd, *in, *out));
}
} else {
CHECK_EQ(useGlobalStats_, false)
<< "useGlobalStats should be false in training";
fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_)
: new bn_fwd(*pd, *in, *out, *mean_, *var_));
}
pipeline.push_back(*fwd_);
}
void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
CHECK(inVal_ && outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc());
resetInGrad(in, inVal_->getPrimitiveDesc());
if (gradScaleShift_) {
CHECK(wgtVal_);
resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
}
}
void MKLDNNBatchNormLayer::resetBwdPD(
std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
pd = nullptr;
if (in == nullptr) {
return;
}
CHECK(out);
CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc());
auto md = in->getMemoryDesc();
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
// TODO(TJ): use check macro
CHECK(wgt);
CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc());
CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
CHECK(mean_);
CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
CHECK(var_);
CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
}
void MKLDNNBatchNormLayer::resetBwdPipeline(
std::vector<primitive>& pipeline,
std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
if (pd == nullptr) {
return;
}
CHECK(inVal_);
bwdData_.reset(
wgt && wgtVal_
? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt)
: new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in));
pipeline.push_back(*bwdData_);
}
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "MKLDNNLayer.h"
#include "mkldnn.hpp"
namespace paddle {
typedef mkldnn::batch_normalization_forward bn_fwd;
typedef mkldnn::batch_normalization_backward bn_bwd;
/**
* @brief A subclass of MKLDNNLayer BatchNorm layer.
*
* The config file api is mkldnn_batch_norm
*/
class MKLDNNBatchNormLayer : public MKLDNNLayer {
protected:
// save forward primitive_desc, which can be used backward
std::shared_ptr<bn_fwd::primitive_desc> fwdPD_;
// Epsilon value used in the batch normalization formula.
static const real EPS;
// weight and bias in paddle
std::unique_ptr<Weight> weight_;
std::unique_ptr<Weight> biases_;
// mkldnn use a large buffer store both scale and shift
// which are weight and bias in paddle corresponding.
MatrixPtr valueScaleShift_;
MatrixPtr gradScaleShift_;
// Moving average of mean.
std::unique_ptr<Weight> movingMean_;
// Moving average of variance.
std::unique_ptr<Weight> movingVar_;
// if useGlobalStats_ is true, will use the loaded mean and variance.
// otherwise, calculate mean and variance in every mini-batch.
bool useGlobalStats_;
// used in MKLDNN primitive desc
unsigned flags_;
// use to compute moving mean and variance.
real movingAvgFraction_;
// whether the weight has been init
bool hasInitedWgt_;
// local mean and variance
// when useGlobalStats_ they are loaded from moving mean and variance
// when do not useGlobalStats_ they are calculated from this mini-batch
MKLDNNMatrixPtr mean_;
MKLDNNMatrixPtr var_;
public:
explicit MKLDNNBatchNormLayer(const LayerConfig& config)
: MKLDNNLayer(config), useGlobalStats_(true), hasInitedWgt_(false) {}
~MKLDNNBatchNormLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override;
void convertWeightsFromPaddle() override;
protected:
void initWeight();
/**
* cal moving mean and variance.
* moving = moving * AvgFraction + local * (1 - AvgFraction)
*/
void calMovingMeanAndVar();
/**
* Forward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in,
MKLDNNMatrixPtr wgt,
MKLDNNMatrixPtr out);
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
void resetBwdPD(std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
};
} // namespace paddle
...@@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() { ...@@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() {
// init randome parameters of ref, and copy to mkldnn // init randome parameters of ref, and copy to mkldnn
void MKLDNNTester::randomWgtDatas() { void MKLDNNTester::randomWgtDatas() {
EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size()); EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size());
const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < parameters_[REF].size(); ++i) { for (size_t i = 0; i < parameters_[REF].size(); ++i) {
const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE); const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE);
const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE); const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE);
parameters_[REF][i]->randomize(); parameters_[REF][i]->randomize();
if (isBN && i == 2) {
// this param is moving average in batch norm, which must larger than 0
real offset = fabs(refValue->getMin()) + 1.0;
refValue->add(offset);
}
dnnValue->copyFrom(*refValue); dnnValue->copyFrom(*refValue);
VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName(); VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName();
...@@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() { ...@@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() {
void MKLDNNTester::checkBackwardData() { void MKLDNNTester::checkBackwardData() {
VLOG(MKLDNN_TESTS) << "Check Backward Data"; VLOG(MKLDNN_TESTS) << "Check Backward Data";
// TODO(TJ): uncomment me when batch norm ready const bool isBN = refLayer_->getType() == "batch_norm";
// const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) { for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad(); const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad();
const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad(); const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad();
...@@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() { ...@@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() {
double delta = compareMatrix(dnnDiff, refDiff); double delta = compareMatrix(dnnDiff, refDiff);
EXPECT_LE(fabs(delta), eps_); EXPECT_LE(fabs(delta), eps_);
// TODO(TJ): uncomment me when batch norm ready if (isBN) {
// if (isBN) { // the other two inputs in batch norm are for moving mean and var
// // the other two inputs in batch norm are for moving mean and var // do not have grad to compare
// break; break;
// } }
} }
} }
...@@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) { ...@@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) {
void MKLDNNTester::runOnce() { void MKLDNNTester::runOnce() {
// test forward // test forward
randomBotDatas(); randomBotDatas();
dnnLayer_->forward(PASS_TRAIN); dnnLayer_->forward(passType_);
refLayer_->forward(PASS_TRAIN); refLayer_->forward(passType_);
checkForward(); checkForward();
if (passType_ == PASS_TEST) {
return;
}
// test backward // test backward
// simple updater // simple updater
UpdateCallback updateCallback = [](Parameter* para) { UpdateCallback updateCallback = [](Parameter* para) {
...@@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn, ...@@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
size_t batchSize, size_t batchSize,
size_t inputImgH, size_t inputImgH,
size_t inputImgW, size_t inputImgW,
PassType passType,
bool printDetails, bool printDetails,
size_t iter, size_t iter,
float epsilon) { float epsilon) {
...@@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn, ...@@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
ih_ = inputImgH; ih_ = inputImgH;
iw_ = inputImgW; iw_ = inputImgW;
passType_ = passType;
log_ = printDetails; log_ = printDetails;
iter_ = iter; iter_ = iter;
eps_ = epsilon; eps_ = epsilon;
......
...@@ -62,12 +62,15 @@ protected: ...@@ -62,12 +62,15 @@ protected:
float eps_; float eps_;
/// input image size, default 1 /// input image size, default 1
size_t ih_, iw_; size_t ih_, iw_;
/// passType, PASS_TRAIN, PASS_TEST or PASS_GC (Gradient Check pass)
PassType passType_;
public: public:
explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) { explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) {
iter_ = iter; iter_ = iter;
eps_ = epsilon; eps_ = epsilon;
log_ = false; log_ = false;
passType_ = PASS_TRAIN;
} }
~MKLDNNTester() {} ~MKLDNNTester() {}
...@@ -78,6 +81,7 @@ public: ...@@ -78,6 +81,7 @@ public:
size_t batchSize, size_t batchSize,
size_t inputImgH = 1, size_t inputImgH = 1,
size_t inputImgW = 1, size_t inputImgW = 1,
PassType passType = PASS_TRAIN,
bool printDetails = false, bool printDetails = false,
size_t iter = 3, size_t iter = 3,
float epsilon = 1e-4); float epsilon = 1e-4);
......
...@@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) { ...@@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) {
testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2}); testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2});
} }
struct testBatchNormDesc {
int bs;
int ic;
int ih, iw;
};
static void getMKLDNNBatchNormConfig(TestConfig& cfg,
const testBatchNormDesc& pm) {
cfg.layerConfig.set_size(pm.ic * pm.ih * pm.iw);
cfg.layerConfig.set_type("mkldnn_batch_norm");
cfg.biasSize = pm.ic;
cfg.inputDefs.push_back(
{INPUT_DATA,
"layer_0",
/* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw),
/* size of weight= */ size_t(pm.ic)});
cfg.inputDefs.push_back(
{INPUT_DATA, "layer_1_moving_mean", 1, size_t(pm.ic)});
cfg.inputDefs.back().isStatic = true;
cfg.inputDefs.push_back({INPUT_DATA, "layer_2_moving_var", 1, size_t(pm.ic)});
cfg.inputDefs.back().isStatic = true;
LayerInputConfig* input = cfg.layerConfig.add_inputs();
// TODO(TJ): uncomment me when refine and support comparing all zeroes vector
// cfg.layerConfig.set_active_type("relu");
cfg.layerConfig.add_inputs();
cfg.layerConfig.add_inputs();
ImageConfig* img_conf = input->mutable_image_conf();
img_conf->set_channels(pm.ic);
img_conf->set_img_size_y(pm.ih);
img_conf->set_img_size(pm.iw);
}
void testBatchNormLayer(const testBatchNormDesc& pm) {
TestConfig dnnConfig;
getMKLDNNBatchNormConfig(dnnConfig, pm);
TestConfig refConfig = dnnConfig;
refConfig.layerConfig.set_type("batch_norm");
// for PASS_TRAIN, use_global_stats always should be false, and batchsize != 1
VLOG(MKLDNN_TESTS) << "check train phase";
dnnConfig.layerConfig.set_use_global_stats(false);
refConfig.layerConfig.set_use_global_stats(false);
MKLDNNTester tester;
tester.run(dnnConfig, refConfig, pm.bs, pm.ih, pm.iw, PASS_TRAIN);
// for PASS_TEST, check use_global_stats true and false, and batchsize 1
VLOG(MKLDNN_TESTS) << "check test phase";
for (auto useGS : {false, true}) {
dnnConfig.layerConfig.set_use_global_stats(useGS);
refConfig.layerConfig.set_use_global_stats(useGS);
MKLDNNTester tester;
for (auto bs : {pm.bs, 1}) {
tester.run(dnnConfig, refConfig, bs, pm.ih, pm.iw, PASS_TEST);
}
}
}
TEST(MKLDNNLayer, BatchNormLayer) {
testBatchNormLayer({4, 10, 6, 6});
testBatchNormLayer({16, 32, 16, 16});
}
struct testActDesc { struct testActDesc {
int bs, ic, ih, iw; int bs, ic, ih, iw;
}; };
......
...@@ -91,6 +91,11 @@ public: ...@@ -91,6 +91,11 @@ public:
const MKLDNNMatrixPtr& dst, const MKLDNNMatrixPtr& dst,
bool checkData = true); bool checkData = true);
void copyFrom(const Matrix& src) {
// TODO(TJ): reorder data if this format is not nchw or x
m_->copyFrom(src);
}
public: public:
/** /**
* Reorder this MKLDNNMatrix from other format. * Reorder this MKLDNNMatrix from other format.
......
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
*/ */
inline real* get(int row) const { inline real* get(int row) const {
if (preallocatedBuf_) { if (preallocatedBuf_) {
CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize()); CHECK_LE((row)*width_ * sizeof(real), preallocatedBuf_->getSize());
return reinterpret_cast<real*>(preallocatedBuf_->getBuf()) + row * width_; return reinterpret_cast<real*>(preallocatedBuf_->getBuf()) + row * width_;
} else { } else {
CHECK_LE((row + 1) * width_, rowStore_.size()); CHECK_LE((row + 1) * width_, rowStore_.size());
......
...@@ -54,6 +54,5 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, ...@@ -54,6 +54,5 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream); cudaStream_t stream);
#endif #endif
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -131,6 +131,7 @@ set(DEPS_OPS ...@@ -131,6 +131,7 @@ set(DEPS_OPS
pool_op pool_op
pool_with_index_op pool_with_index_op
conv_op conv_op
sequence_conv_op
lstm_op) lstm_op)
...@@ -139,10 +140,11 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc ...@@ -139,10 +140,11 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(conv_op DEPS vol2col) op_library(conv_op DEPS vol2col)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling) op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling) op_library(pool_with_index_op DEPS pooling)
op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
...@@ -157,3 +159,4 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) ...@@ -157,3 +159,4 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
...@@ -449,9 +449,13 @@ REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>, ...@@ -449,9 +449,13 @@ REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>,
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \ REGISTER_OP_CPU_KERNEL( \
act_type, \ act_type, \
ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>); \ ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>, \
REGISTER_OP_CPU_KERNEL(act_type##_grad, \ ops::ActivationKernel<paddle::platform::CPUPlace, \
ops::functor<double>>); \
REGISTER_OP_CPU_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CPUPlace, \ ops::ActivationGradKernel<paddle::platform::CPUPlace, \
ops::grad_functor<float>>); ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
...@@ -20,9 +20,13 @@ namespace ops = paddle::operators; ...@@ -20,9 +20,13 @@ namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \ #define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \ REGISTER_OP_GPU_KERNEL( \
act_type, \ act_type, \
ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>); \ ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>, \
REGISTER_OP_GPU_KERNEL(act_type##_grad, \ ops::ActivationKernel<paddle::platform::GPUPlace, \
ops::functor<double>>); \
REGISTER_OP_GPU_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::GPUPlace, \ ops::ActivationGradKernel<paddle::platform::GPUPlace, \
ops::grad_functor<float>>); ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
...@@ -210,8 +210,8 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -210,8 +210,8 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval(); auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval(); auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2); y.device(d) = x * (temp1 + temp2);
} }
}; };
...@@ -226,8 +226,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -226,8 +226,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval(); auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval(); auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>(); dx.device(d) = dy * (temp1 + temp2).template cast<T>();
} }
}; };
...@@ -243,9 +243,10 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -243,9 +243,10 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
auto temp1 = (x > lambda).template cast<T>().eval(); auto lambdaT = static_cast<T>(lambda);
auto temp2 = (x < -lambda).template cast<T>().eval(); auto temp1 = (x > lambdaT).template cast<T>().eval();
y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda); auto temp2 = (x < -lambdaT).template cast<T>().eval();
y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
} }
}; };
...@@ -257,8 +258,9 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -257,8 +258,9 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x > lambda).template cast<T>().eval(); auto lambdaT = static_cast<T>(lambda);
auto temp2 = (x < -lambda).template cast<T>().eval(); auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>(); dx.device(d) = dy * (temp1 + temp2).template cast<T>();
} }
}; };
...@@ -362,7 +364,8 @@ struct BReluFunctor : public BaseActivationFunctor<T> { ...@@ -362,7 +364,8 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max); y.device(d) =
x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
} }
}; };
...@@ -375,7 +378,9 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -375,7 +378,9 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast<T>(); dx.device(d) = dy *
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
.template cast<T>();
} }
}; };
...@@ -390,7 +395,8 @@ struct Relu6Functor : public BaseActivationFunctor<T> { ...@@ -390,7 +395,8 @@ struct Relu6Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)).cwiseMin(threshold); y.device(d) =
x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
} }
}; };
...@@ -402,8 +408,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -402,8 +408,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dx.device(d) = dy *
dy * ((x > static_cast<T>(0)) * (x < threshold)).template cast<T>(); ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
.template cast<T>();
} }
}; };
...@@ -463,7 +470,8 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> { ...@@ -463,7 +470,8 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold); auto tmp = static_cast<T>(threshold);
auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast<T>(1) + temp.exp()).log(); y.device(d) = (static_cast<T>(1) + temp.exp()).log();
} }
}; };
...@@ -476,7 +484,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -476,7 +484,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval(); auto tmp = static_cast<T>(threshold);
auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp; dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
} }
}; };
...@@ -490,7 +499,7 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> { ...@@ -490,7 +499,7 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(alpha * x); y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
} }
}; };
...@@ -502,7 +511,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -502,7 +511,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = alpha * (x < static_cast<T>(0)).template cast<T>().eval(); auto temp1 = static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval(); auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>(); dx.device(d) = dy * (temp1 + temp2).template cast<T>();
} }
...@@ -517,9 +527,9 @@ struct ELUFunctor : public BaseActivationFunctor<T> { ...@@ -517,9 +527,9 @@ struct ELUFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = y.device(d) = x.cwiseMax(static_cast<T>(0)) +
x.cwiseMax(static_cast<T>(0)) + (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
(alpha * (x.exp() - static_cast<T>(1))).cwiseMin(static_cast<T>(0)); .cwiseMin(static_cast<T>(0));
} }
}; };
...@@ -531,9 +541,9 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -531,9 +541,9 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (x > static_cast<T>(0)).template cast<T>() + dy * (y + static_cast<T>(alpha)) *
dy * (y + alpha) * (x < static_cast<T>(0)).template cast<T>(); (x < static_cast<T>(0)).template cast<T>();
} }
}; };
...@@ -545,7 +555,7 @@ struct PowFunctor : public BaseActivationFunctor<T> { ...@@ -545,7 +555,7 @@ struct PowFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(factor); y.device(d) = x.pow(static_cast<T>(factor));
} }
}; };
...@@ -557,7 +567,8 @@ struct PowGradFunctor : public BaseActivationFunctor<T> { ...@@ -557,7 +567,8 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * factor * x.pow(factor - static_cast<T>(1)); dx.device(d) = dy * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1)));
} }
}; };
...@@ -571,7 +582,8 @@ struct STanhFunctor : public BaseActivationFunctor<T> { ...@@ -571,7 +582,8 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = scale_b * (scale_a * x).tanh(); y.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
} }
}; };
...@@ -585,8 +597,10 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -585,8 +597,10 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh(); auto a = static_cast<T>(scale_a);
dx.device(d) = dy * scale_a * scale_b * (static_cast<T>(1) - temp); auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dy * a * b * (static_cast<T>(1) - temp);
} }
}; };
...@@ -599,7 +613,8 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> { ...@@ -599,7 +613,8 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Y y) const {
y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x; auto th = static_cast<T>(threshold);
y.device(d) = (x > th).template cast<T>() * x;
} }
}; };
...@@ -612,7 +627,8 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -612,7 +627,8 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>(); auto th = static_cast<T>(threshold);
dx.device(d) = dy * (x > th).template cast<T>();
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/batch_norm_op.h"
#include <cfloat>
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
void ExtractNCWHD(const framework::DDim &dims,
const TensorFormat &tensor_format, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
*C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
*H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
: 1;
}
template <typename T>
class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
VLOG(1) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * D * C, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0);
// FIXME(qiao) should not set zero self
functor(ctx.device_context(), mean_out, 0);
functor(ctx.device_context(), variance_out, 0);
auto handle = ctx.cuda_device_context().cudnn_handle();
// Now, depending on whether we are running test or not, we have two paths.
if (is_test) {
// only when test we use input to do computation.
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
// Run inference mode.
PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(est_mean->dims()[0], C);
PADDLE_ENFORCE_EQ(est_var->dims()[0], C);
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference(
handle,
// Note: PERSISTENT not implemented for inference
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// initialize them.
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(), bias->template data<T>(), this_factor,
mean_out->template mutable_data<T>(ctx.GetPlace()),
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean->template mutable_data<T>(ctx.GetPlace()),
saved_variance->template mutable_data<T>(ctx.GetPlace())));
}
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
template <typename T>
class BatchNormGradKernel<platform::GPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
std::vector<int> dims = {N, C, H, W, D};
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
ctx.cuda_device_context().cudnn_handle(), mode_,
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
x->template data<T>(), data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(),
d_scale->template mutable_data<T>(ctx.GetPlace()),
d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean_data, saved_var_data));
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(batch_norm,
ops::BatchNormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::GPUPlace, float>);
...@@ -197,11 +197,11 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> { ...@@ -197,11 +197,11 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
// convolution backward weight operator: im2col + gemm // convolution backward weight operator: im2col + gemm
int in_step = input_channels / groups; int in_step = input_channels / groups;
int out_step = output_channels / groups; int out_step = output_channels / groups;
math::SetConstant<Place, T> set_zero;
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad); set_zero(context.device_context(), input_grad, static_cast<T>(0));
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
...@@ -230,8 +230,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> { ...@@ -230,8 +230,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad; Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape); filter_grad_.Resize(filter_matrix_shape);
auto t = framework::EigenVector<T>::Flatten(filter_grad_); set_zero(context.device_context(), filter_grad, static_cast<T>(0));
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
...@@ -410,11 +409,11 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> { ...@@ -410,11 +409,11 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
// convolution backward weight operator: vol2col + gemm // convolution backward weight operator: vol2col + gemm
int in_step = input_channels / groups; int in_step = input_channels / groups;
int out_step = output_channels / groups; int out_step = output_channels / groups;
math::SetConstant<Place, T> set_zero;
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad); set_zero(context.device_context(), input_grad, static_cast<T>(0));
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
...@@ -443,8 +442,7 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> { ...@@ -443,8 +442,7 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad; Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape); filter_grad_.Resize(filter_matrix_shape);
auto t = framework::EigenVector<T>::Flatten(filter_grad_); set_zero(context.device_context(), filter_grad, static_cast<T>(0));
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
......
...@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`. ...@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
cross_entropy_grad, ops::CrossEntropyGradientOp); cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>); REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
ops::CrossEntropyOpKernel<double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad, REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<float>); ops::CrossEntropyGradientOpKernel<float>,
ops::CrossEntropyGradientOpKernel<double>);
...@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> { ...@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>); REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad, REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>); ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>);
...@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == 1) { if (ctx->Attrs().Get<bool>("is_training") == true) {
ctx->SetOutputDim("Mask", x_dims); ctx->SetOutputDim("Mask", x_dims);
} }
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
DropoutOpMaker(framework::OpProto* proto, DropoutOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.") AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f); .SetDefault(.5f);
AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true); AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
...@@ -69,7 +69,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -69,7 +69,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1, PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
"GradOp is only callable when is_training is true"); "GradOp is only callable when is_training is true");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
...@@ -77,8 +77,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -77,8 +77,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null."); "Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims, out_dims, PADDLE_ENFORCE_EQ(x_dims, out_dims,
......
...@@ -33,7 +33,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y = context.Output<Tensor>("Out"); auto* y = context.Output<Tensor>("Out");
const auto* x_data = x->data<T>(); const auto* x_data = x->data<T>();
auto* y_data = y->mutable_data<T>(context.GetPlace()); auto* y_data = y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
if (context.Attr<bool>("is_training")) { if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
...@@ -41,7 +41,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -41,7 +41,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
int seed = context.Attr<int>("seed"); int seed = context.Attr<int>("seed");
std::minstd_rand engine; std::minstd_rand engine;
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<AttrType> dist(0, 1); std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(mask->dims()); size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) { if (dist(engine) < dropout_prob) {
......
...@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase { ...@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate // FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs? // CPU outputs?
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx); dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
dst_item.set_lod(src_item.lod()); dst_item.set_lod(src_item.lod());
VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_constant_batch_size_like_op.h"
namespace paddle {
namespace operators {
class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Input"),
"Input(Input) of FillConstantBatchSizeLikeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FillConstantBatchSizeLikeOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto dims = framework::make_ddim(shape_int64);
dims[0] = ctx->GetInputDim("Input")[0];
ctx->SetOutputDim("Out", dims);
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
}
};
class FillConstantBatchSizeLikeOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddInput("Input",
"(Tensor) Tensor "
"whose first dimension is used to specify the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddComment(R"DOC(Fill up a variable with specified constant value.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp,
ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_constant_batch_size_like_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value");
auto out_eigen = framework::EigenVector<T>::Flatten(*out);
auto place = ctx.GetEigenDevice<Place>();
out_eigen.device(place) = out_eigen.constant(static_cast<T>(value));
}
};
} // namespace operators
} // namespace paddle
...@@ -64,5 +64,6 @@ namespace ops = paddle::operators; ...@@ -64,5 +64,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp, REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
ops::FillConstantOpMaker); ops::FillConstantOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_constant, fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>); ops::FillConstantOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>);
...@@ -18,5 +18,6 @@ ...@@ -18,5 +18,6 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
fill_constant, fill_constant, ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>); ops::FillConstantOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>);
...@@ -25,7 +25,7 @@ class FillConstantOpKernel : public framework::OpKernel<T> { ...@@ -25,7 +25,7 @@ class FillConstantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<T>("value"); auto value = ctx.Attr<float>("value");
auto out_eigen = framework::EigenVector<T>::Flatten(*out); auto out_eigen = framework::EigenVector<T>::Flatten(*out);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
......
...@@ -171,8 +171,7 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -171,8 +171,7 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3, weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); "The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto bias = Input("Bias"); if (ctx->HasInput("Bias")) {
if (bias != framework::kEmptyVarName) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0]; int bias_height = bias_dims[0];
int bias_width = bias_dims[1]; int bias_width = bias_dims[1];
...@@ -203,6 +202,8 @@ namespace ops = paddle::operators; ...@@ -203,6 +202,8 @@ namespace ops = paddle::operators;
REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad, REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad,
ops::GRUUnitGradOp); ops::GRUUnitGradOp);
REGISTER_OP_CPU_KERNEL(gru_unit, REGISTER_OP_CPU_KERNEL(gru_unit,
ops::GRUUnitKernel<paddle::platform::CPUPlace, float>); ops::GRUUnitKernel<paddle::platform::CPUPlace, float>,
ops::GRUUnitKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>); gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>,
ops::GRUUnitGradKernel<paddle::platform::CPUPlace, double>);
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru_unit, REGISTER_OP_GPU_KERNEL(gru_unit,
ops::GRUUnitKernel<paddle::platform::GPUPlace, float>); ops::GRUUnitKernel<paddle::platform::GPUPlace, float>,
ops::GRUUnitKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>); gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>,
ops::GRUUnitGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/l1_norm_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class L1NormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
ctx->SetOutputDim("Out", {1});
}
};
class L1NormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
L1NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of l1_norm op.");
AddOutput("Out", "(Scalar) The output of l1_norm op.");
AddComment(R"DOC(
L1 Norm Operator.
Computes the L1 norm of a tensor.
Out = sum (abs(X))
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(l1_norm, ops::L1NormOp, ops::L1NormOpMaker, l1_norm_grad,
ops::L1NormGradOp);
REGISTER_OP_CPU_KERNEL(l1_norm,
ops::L1NormKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
l1_norm_grad, ops::L1NormGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/l1_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(l1_norm,
ops::L1NormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
l1_norm_grad, ops::L1NormGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
// Out = sum(abs(X))
template <typename Place, typename T>
class L1NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X");
framework::Tensor *Out = context.Output<framework::Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.abs().sum();
}
};
// dX = dout * sign(X)
template <typename Place, typename T>
class L1NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *x = context.Input<framework::Tensor>("X");
const framework::Tensor *d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(d_out->numel() == 1, "L1 Norm Gradient should be scalar");
framework::Tensor *dx =
context.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(context.GetPlace());
auto x_eigen = framework::EigenVector<T>::Flatten(*x);
auto d_out_eigen = framework::EigenVector<T>::Flatten(*d_out);
auto dx_eigen = framework::EigenVector<T>::Flatten(*dx);
auto place = context.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> x_dsize(x->numel());
dx_eigen.device(place) = d_out_eigen.broadcast(x_dsize) * x_eigen.sign();
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include <fstream>
namespace paddle {
namespace operators {
class LoadOp : public framework::OperatorBase {
public:
LoadOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name);
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
uint32_t version;
fin.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
fin.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(),
std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case framework::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case framework::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case framework::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case framework::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
PADDLE_THROW("DataType %d not supported", desc.data_type());
}
fin.read(static_cast<char *>(buf), tensor->memory_size());
}
{ // read lod
uint64_t lod_level;
fin.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
fin.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
}
auto place = dev_ctx.GetPlace();
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
framework::LoDTensor cpu_tensor;
cpu_tensor.ShareDataWith(*tensor);
cpu_tensor.set_lod(tensor->lod());
// reset tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
tensor->CopyFrom(cpu_tensor, place, dev_ctx);
}
}
};
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The tensor need to be loaded");
AddComment(R"DOC(Load Operator
Load operator will load a tensor variable from disk file.
)DOC");
AddAttr<std::string>("file_path",
"Variable will be loaded from \"file_path\".")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lrn_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class LRNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LRNOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LRNOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MidOut"),
"MidOut(Out) of LRNOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
ctx->SetOutputDim("Out", x_dim);
ctx->SetOutputDim("MidOut", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LRNOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", R"DOC(
(Tensor) The input of LRN operator. It must be a 4D tenor with NCHW format.
)DOC");
AddOutput("Out",
"(Tensor) The output of LRN operator, which is also the 4D "
"tensor with NCHW format.");
AddOutput("MidOut", R"Doc(
(Tensor)Middle result of lrn op.It's computed in forward process
and also used in backward process.
)Doc");
AddAttr<int>("n", R"DOC(
(int, default 5)n is “adjacent” kernel maps at the same spatial position.
)DOC")
.SetDefault(5)
.GreaterThan(0);
AddAttr<T>("k", R"DOC(
(float, default 2.0)k is the bias.
)DOC")
.SetDefault(2.0)
.GreaterThan(0.0);
AddAttr<T>("alpha", R"DOC(
(float, default 0.0001)alpha is the scale number.
)DOC")
.SetDefault(0.0001)
.GreaterThan(0.0);
AddAttr<T>("beta", R"DOC(
(float, default 0.75)beta is the power number.
)DOC")
.SetDefault(0.75)
.GreaterThan(0.0);
AddComment(R"DOC(
Local Response Normalization.
This Function comes from the paper
"ImageNet Classification with Deep Convolutional Neural Networks".
The original formula is:
Input(i, x, y)
Output(i, x, y) = ----------------------------------------------
-- upper
(k + alpha * > (Input(j, x, y))^2) ^ (beta)
-- j = lower
upper is `min(C, c + n/2)`
lower if `max(0, c - n/2)`
Function implementation:
inputs and outpus is NCHW format, while input.shape.ndims() is equal 4.
And the meaning of each dimension(0-3) is respectively batch size,
feature maps, rows and columns.
Input and Output in the above formula is for each map(i) of one image, and
Input(i, x, y), Output(i, x, y) represents an element in an image.
C is the number of feature maps of one image, and n is a hyper-parameters
is configured when Function is initialized. The sum in the denominator
is the sum of the same position in the neighboring maps.
)DOC");
}
};
class LRNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("MidOut")),
"Input(MidOut@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lrn, ops::LRNOp, ops::LRNOpMaker<float>, lrn_grad, ops::LRNOpGrad);
REGISTER_OP_CPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(lrn_grad,
ops::LRNGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lrn_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(lrn_grad,
ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
此差异已折叠。
...@@ -9,6 +9,7 @@ if(WITH_GPU) ...@@ -9,6 +9,7 @@ if(WITH_GPU)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
else() else()
...@@ -18,6 +19,7 @@ else() ...@@ -18,6 +19,7 @@ else()
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
endif() endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/context_project.h"
namespace paddle {
namespace operators {
namespace math {
template class ContextProjectFunctor<platform::CPUPlace, float>;
template class ContextProjectFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -12,28 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,28 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
syntax = "proto2"; #define EIGEN_USE_GPU
option optimize_for = LITE_RUNTIME;
package paddle.framework; #include "paddle/operators/math/context_project.h"
import "framework.proto"; namespace paddle {
namespace operators {
/** namespace math {
* This file contains necessary information for model, checkpoint.
* etc. template class ContextProjectFunctor<platform::GPUPlace, float>;
*/ template class ContextProjectFunctor<platform::GPUPlace, double>;
message LoDInfo { repeated int64 level = 1; } } // namespace math
} // namespace operators
/** } // namespace paddle
* Save the LoDTensorDesc information through LoDTensorProto, its data memory
* is copyed to c buffer immediately. See model_format.md for details.
*/
message LoDTensorProto {
optional DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
repeated LoDInfo levels = 3;
optional int32 lod_level = 4 [ default = 0 ];
optional int32 version = 5;
}
此差异已折叠。
...@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> { ...@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
}; };
template class CrossEntropyFunctor<platform::CPUPlace, float>; template class CrossEntropyFunctor<platform::CPUPlace, float>;
template class CrossEntropyFunctor<platform::CPUPlace, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -71,7 +71,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker { ...@@ -71,7 +71,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker); REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp); REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean, REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<paddle::platform::CPUPlace, float>,
ops::MeanKernel<paddle::platform::CPUPlace, float>); ops::MeanKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(mean_grad, REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>); ops::MeanGradKernel<paddle::platform::CPUPlace, float>,
ops::MeanGradKernel<paddle::platform::CPUPlace, double>);
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册