提交 16a39d24 编写于 作者: D Dong Zhihong

fix conflict

......@@ -12,27 +12,25 @@ The topology is saved as a plain text in a detailed self-contain protobuf file.
The parameters are saved as a binary file. As we all know, the protobuf message has a limit of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We have done a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), which shows that protobuf is not fit for the task.
As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**|
As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
The table below shows a tensor's byte view in detail. Note that all the signed values are written in the little-endian format.
```text
[offset] [type] [description]
0004 4 bytes integer HeaderLength, the length of LoDTensorDesc
0008 4 bytes integer ContentLength, the length of LodTensor Buffer
0009 1 bytes char TensorDesc
00010 1 bytes char TensorDesc
...
00100 1 bytes char TensorValue
00101 1 bytes char TensorValue
00102 1 bytes char TensorValue ..
...
```
|field name | type | description |
| --- | --- | --- |
| version | uint32_t | Version of saved file. Always 0 now. |
| tensor desc length | uint32_t | TensorDesc(Protobuf message) length in bytes. |
| tensor desc | void* | TensorDesc protobuf binary message |
| tensor data | void* | Tensor's data in binary format. The length of `tensor_data` is decided by `TensorDesc.dims()` and `TensorDesc.data_type()` |
| lod_level | uint64_t | Level of LoD |
| length of lod[0] | uint64_t | [Optional] length of lod[0] in bytes. |
| data of lod[0] | uint64_t* | [Optional] lod[0].data() |
| ... | ... | ... |
## Summary
- We introduce a model format.
- The `ProgramDesc` describe the model **topology**.
- The model represented by its forward-pass computation procedure is saved in a **ProgramDesc** protobuf message.
- A bunch of specified format binary tensors describe the **parameters**.
# Regularization in PaddlePaddle
## Introduction to Regularization
A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**.
A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. A frequently faced problem is the problem of **overfitting**, where the model does not make reliable predictions on new unseen data. **Regularization** is the process of introducing additional information in order to prevent overfitting. This is usually done by adding extra penalties to the loss function that restricts the parameter spaces that an optimization algorithm can explore.
### Parameter Norm Penalties
Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows:
......@@ -18,52 +18,21 @@ The most commonly used norm penalties are the L2 norm penalty and the L1 norm pe
##### L1 Regularization
<img src="./images/l1_regularization.png" align="center"/><br/>
A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
A much more detailed mathematical background of regularization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
## Regularization Survey
## How to do Regularization in PaddlePaddle
On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization:
1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows:
```python
opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2)
```
At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet:
```python
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
```
This is a very restyrictive way of doing regularization and does not give the users enough flexibility.
**Advantages**:
- It is easy to implement for us.
- Faster execution of backward. However, it can be done manually by advanced users too.
**Disadvantages**:
- Not flexible for other regularizations such as L1/L0 regularization.
- Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized.
- Tightly coupled optimizer and regularization implementation.
2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer.
**Advantages**:
- Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization.
- Makes it easy for the users to customize and extend the framework.
**Disadvantages**:
- Implementation requires comprehensive design and time.
A detailed survey of regularization in various deep learning frameworks can be found [here](https://github.com/PaddlePaddle/Paddle/wiki/Regularization-Survey).
## Proposal for Regularization in PaddlePaddle
### Low-Level implementation
In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations:
In the new design, we propose to create new operations for regularization. For now, we can add 2 ops that correspond to the most frequently used regularizations:
- L2_regularization_op
- L1_regularization_op
These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties.
These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate CPU and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes other than L1 and L2 norm penalties.
The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API.
......@@ -94,7 +63,7 @@ Since we want to create the regularization ops in a lazy manner, the regularizat
#### High-level API
In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we also need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
......
......@@ -67,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Info("Could not find the pserver checkpoint.")
log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
......@@ -99,7 +99,7 @@ func main() {
candy.Must(err)
go func() {
log.Info("starting pserver", log.Ctx{"port": *port})
log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
......
......@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}
var cptr (*C.uchar)
if len(c) > 0 {
cptr = (*C.uchar)(&c[0])
} else {
log.Error("empty config", "param name", paramWithConfigs.Param.Name)
}
o.config = c
o.opt = C.paddle_create_optimizer(
(*C.uchar)(&c[0]),
cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
......
......@@ -17,12 +17,11 @@ package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
......@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
// RPC error message.
const (
......@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
MD5 string `json:"md5"`
CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}
......@@ -92,7 +91,7 @@ type Service struct {
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
client KVStore
mu sync.Mutex
optMap map[string]*optimizer
......@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte
}
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
type KVStore interface {
GetKey(key string, timeout time.Duration) ([]byte, error)
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}
func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
......@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")
......@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err
}
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
crc32 := crc32.ChecksumIEEE(content)
if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}
......@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil {
return nil, err
}
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
......@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
close(s.initialized)
}
return s, nil
}
......@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t {
err := s.checkpoint()
if err != nil {
log.Error("finish init params error", log.Ctx{"error": err})
log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
......@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
......@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Info("Do not have existing checkpoint.")
log.Info("old meta not found, skip removing old meta")
err = nil
} else if err == nil {
log.Info("removing old meta")
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
}
if err != nil {
return
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
MD5: md5,
CRC32: crc32,
Path: p,
}
......@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
return
}
package pserver
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const testDir = "./test_data"
type myKV struct {
m map[string][]byte
}
func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
if m.m == nil {
m.m = make(map[string][]byte)
}
return m.m[key], nil
}
func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
if m.m == nil {
m.m = make(map[string][]byte)
}
m.m[key] = value
return nil
}
func TestCheckpoint(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
err = s.checkpoint()
assert.Nil(t, err)
_, err = LoadCheckpoint(kv, 0)
assert.Nil(t, err)
}
func float32ToByte(f float32) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, f)
if err != nil {
fmt.Println("binary.Write failed:", err)
}
return buf.Bytes()
}
func TestCheckpointWithData(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
var content []byte
for i := 0; i < 50000; i++ {
content = append(content, float32ToByte(float32(i))...)
}
p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
assert.Nil(t, err)
err = s.FinishInitParams(0, nil)
assert.Nil(t, err)
var p2 Parameter
err = s.GetParam(p1.Name, &p2)
assert.Nil(t, err)
assert.Equal(t, p1, p2)
err = s.checkpoint()
assert.Nil(t, err)
cp, err := LoadCheckpoint(kv, 0)
assert.Nil(t, err)
s1, err := NewService(0, time.Hour, testDir, kv, cp)
assert.Nil(t, err)
var p3 Parameter
err = s1.GetParam(p1.Name, &p3)
assert.Nil(t, err)
assert.Equal(t, p1, p3)
}
......@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait()
}
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}
......@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config;
paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
!modelConfig.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
}
} else {
modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create(
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
for (auto& para : parameters) {
para->load(is);
......
# ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(saver_proto SRCS framework.proto saver.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
......@@ -10,7 +9,7 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
......
......@@ -15,6 +15,7 @@
#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
......
......@@ -13,7 +13,6 @@
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/saver.pb.h"
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
......@@ -106,6 +105,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const {
return lod_[level][idx + 1] - lod_[level][idx];
}
size_t LoDTensor::NumInstancesInElement(size_t level, size_t idx) const {
PADDLE_ENFORCE_LT(level, NumLevels());
PADDLE_ENFORCE_LT(idx, NumElements(level));
auto abs_lod = ToAbsOffset(lod());
size_t begin = abs_lod[level][idx];
size_t end = abs_lod[level][idx + 1];
return end - begin;
}
void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod;
......@@ -117,144 +125,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
PADDLE_ENFORCE_LT(elem_begin, NumElements(level));
PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1);
auto abs_lod = framework::ToAbsOffset(lod());
auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end);
lod_ = new_lod;
}
std::string LoDTensor::SerializeToString() const {
LoDTensorProto desc;
// set data_type
if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL);
if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16);
if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32);
if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64);
// FIXME(dzh): there is no fp16 in standard c++
if (this->type() == typeid(float)) // NOLINT
desc.set_data_type(DataType::FP32);
if (this->type() == typeid(double)) // NOLINT
desc.set_data_type(DataType::FP64);
for (int i = 0; i < dims().size(); ++i) {
desc.add_dims(dims()[i]);
}
// set lod information
desc.set_lod_level(this->NumLevels());
for (size_t i = 0; i < this->NumLevels(); ++i) {
LoDInfo* lod = desc.add_levels();
for (size_t j = 0; j < lod_[i].size(); ++j) {
lod->add_level(lod_[i][j]);
}
}
desc.set_version(0);
std::string desc_bytes = desc.SerializeAsString();
// FIXME(dzh) : implement fix chunk size buffer.
size_t DESC_SIZE = desc_bytes.size();
size_t DATA_SIZE = holder_->size() - offset_;
const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t);
char* buffer =
static_cast<char*>(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE));
// format: desc_size data_size, desc_bytes, data_bytes.
platform::CPUPlace src_place;
platform::CPUPlace dst_place;
memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t));
memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE,
sizeof(size_t));
memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place,
desc_bytes.c_str(), desc_bytes.size());
PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!");
platform::Place place = holder_->place();
int element_width = holder_->size() / this->numel();
if (platform::is_cpu_place(place)) {
memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
boost::get<platform::CPUPlace>(place),
static_cast<char*>(holder_->ptr()) + offset_ / element_width,
DATA_SIZE);
}
#ifdef PADDLE_WITH_GPU
if (platform::is_gpu_place(place)) {
memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
boost::get<platform::GPUPlace>(place),
static_cast<char*>(holder_->ptr()) + offset_ / element_width,
DATA_SIZE);
}
#endif
std::string ret(buffer, BUFFER_SIZE);
memory::Free(platform::CPUPlace(), buffer);
return ret;
// slice the underlying tensor
size_t begin = abs_lod[level][elem_begin];
size_t end = abs_lod[level][elem_end];
PADDLE_ENFORCE_LT(begin, end, "Cannot shrink, the result tensor is empty.");
ShareDataWith(Slice(begin, end));
}
void LoDTensor::DeserializeFromString(const std::string& s,
const platform::Place& dst_place) {
size_t DESC_SIZE, BUFFER_SIZE;
platform::CPUPlace src_place;
memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t));
memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t),
sizeof(size_t));
const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2;
// parse LoDTensorDesc
LoDTensorProto desc;
desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE);
std::vector<int64_t> dims;
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
this->Resize(make_ddim(dims));
// parse data type
void* ptr = nullptr;
if (desc.data_type() == DataType::BOOL)
ptr = this->mutable_data<bool>(dst_place);
if (desc.data_type() == DataType::INT16)
ptr = this->mutable_data<int16_t>(dst_place);
if (desc.data_type() == DataType::INT32)
ptr = this->mutable_data<int32_t>(dst_place);
if (desc.data_type() == DataType::INT64)
ptr = this->mutable_data<int64_t>(dst_place);
// FIXME(dzh): there is no fp16 in standard c++
if (desc.data_type() == DataType::FP32)
ptr = this->mutable_data<float>(dst_place);
if (desc.data_type() == DataType::FP64)
ptr = this->mutable_data<double>(dst_place);
LoD lod;
std::vector<size_t> levels;
for (int i = 0; i < desc.levels().size(); ++i) {
auto current_level = desc.levels()[i].level();
std::copy(current_level.begin(), current_level.end(),
std::back_inserter(levels));
lod.emplace_back(levels);
levels.clear();
}
this->set_lod(lod);
if (platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), ptr, src_place,
s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
}
#ifdef PADDLE_WITH_GPU
if (platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), ptr, src_place,
s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
}
#endif
}
} // namespace framework
} // namespace paddle
......@@ -85,7 +85,9 @@ class LoDTensor : public Tensor {
void set_lod(const LoD& lod) { lod_ = lod; }
LoD lod() const { return lod_; }
const LoD& lod() const { return lod_; }
LoD* mutable_lod() { return &lod_; }
/*
* Get the start offset and end offset of an element from LoD.
......@@ -122,6 +124,12 @@ class LoDTensor : public Tensor {
*/
size_t NumElements(size_t level, size_t idx) const;
/*
* Get the number of instances in the underlying tensor in the `idx`-th
* element.
*/
size_t NumInstancesInElement(size_t level, size_t idx) const;
/*
* Shrink levels[level_begin:level_end]
*/
......@@ -133,29 +141,45 @@ class LoDTensor : public Tensor {
*/
void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end);
/**
* @brief Serialize tensor to char bytes.
* Please check model_format.md for the format detail.
* NOTE: GPUTensor will copy data to cpu implicitly.
* @return return string
*/
// FIXME(dzh) : Currently, this interface should only be used in
// save/restore model and checkpoint. ParameterServer do not use shape
// information to do the optimization, as a result, when we serialize
// parameter/gradient to string, we should serialize the tensor
// to string in the ps trainer instead of LoDTensor.
std::string SerializeToString() const;
/**
* @brief Deserialize char bytes to tensor.
* @return return string
*/
void DeserializeFromString(const std::string& s,
const platform::Place& dst_place);
private:
LoD lod_;
};
/*
* Expand the `source` to fit the LoD of `lod`. For example, a `source`
* LoDTensor is
* - LoD: [0, 2]
* - tensor: [a0, a1]
* a `lod` is
* - LoD: [0 3 5]
* returns a new LoDTensor
* - [a0 a0 a0 a1 a1]
*/
template <typename T>
LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
const platform::Place& place) {
LoD abs_lod = ToAbsOffset(lod);
const auto& lod_level = lod[level];
size_t num_instances = source.dims()[0];
// new tensor
LoDTensor tensor;
tensor.set_lod(lod);
auto dims = source.dims();
dims[0] = lod_level.back();
tensor.Resize(dims);
tensor.mutable_data<T>(place);
PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
tensor.Slice(elem, elem + 1)
.CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
return tensor;
}
} // namespace framework
} // namespace paddle
......@@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
size_t level = 0;
LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.ShrinkInLevel(level, 0, 1);
EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL);
EXPECT_EQ(new_lod_tensor.NumElements(0), 1UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 5UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());
ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(2), 5UL);
ASSERT_EQ(new_lod_tensor.dims()[0], 12);
for (int i = 0; i < 12 * 128; i++) {
ASSERT_EQ(new_lod_tensor.data<float>()[i], i);
}
level = 1;
new_lod_tensor = lod_tensor_;
......@@ -104,23 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());
ASSERT_EQ(new_lod_tensor.dims()[0], 7);
for (int i = 5 * 128; i < 12 * 128; i++) {
ASSERT_EQ(new_lod_tensor.data<float>()[i - 5 * 128], i);
}
LoDTensor t1;
t1.set_lod(lod_tensor_.lod());
t1.ShareDataWith(lod_tensor_);
LoDTensor t2;
t2.set_lod(lod_tensor_.lod());
t2.ShareDataWith(lod_tensor_);
t1.ShrinkInLevel(0, 1, 2);
t2.ShrinkInLevel(0, 0, 1);
EXPECT_NE(t1.data<float>(), t2.data<float>());
EXPECT_NE(t1.data<float>(), lod_tensor_.data<float>());
}
TEST_F(LoDTensorTester, SerializeDeserialize) {
LoDTensor new_lod_tensor = lod_tensor_;
float* src_ptr = lod_tensor_.data<float>();
std::string s = lod_tensor_.SerializeToString();
LoDTensor dst;
dst.DeserializeFromString(s, platform::CPUPlace());
float* dst_ptr = dst.data<float>();
for (int i = 0; i < kLodTensorSize; ++i) {
EXPECT_EQ(dst_ptr[i], src_ptr[i]);
TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
tensor.set_lod(lod);
tensor.Resize({2, 1});
tensor.mutable_data<float>(platform::CPUPlace());
tensor.data<float>()[0] = 0;
tensor.data<float>()[1] = 1;
LoD target;
target.emplace_back(std::vector<size_t>{0, 3, 5});
auto new_tensor = LodExpand<float>(tensor, target, 0UL, platform::CPUPlace());
std::vector<int> result{{0, 0, 0, 1, 1}};
for (size_t i = 0; i < 5; i++) {
ASSERT_EQ(new_tensor.data<float>()[i], result[i]);
}
ASSERT_EQ(dst.NumElements(0), 2UL);
ASSERT_EQ(dst.NumElements(1), 3UL);
ASSERT_EQ(dst.NumElements(2), 8UL);
}
} // namespace framework
......
......@@ -47,31 +47,4 @@ TEST(LoDTensor, LoDInGPU) {
for (size_t i = 0; i < src_lod[0].size(); ++i) {
CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
}
}
TEST(LoDTensor, SerializeDeserialize) {
paddle::framework::LoDTensor lod_tensor;
paddle::platform::GPUPlace place(0);
paddle::framework::LoD src_lod;
src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});
lod_tensor.Resize({14, 16});
lod_tensor.mutable_data<float>(place);
lod_tensor.set_lod(src_lod);
CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
test<<<1, 8>>>(src_lod[0].data(), src_lod[0].size());
cudaDeviceSynchronize();
std::string s = lod_tensor.SerializeToString();
paddle::framework::LoDTensor dst;
dst.DeserializeFromString(s, place);
paddle::framework::LoD dst_lod = dst.lod();
for (size_t i = 0; i < dst_lod[0].size(); ++i) {
CHECK_EQ(src_lod[0].data()[i], dst_lod[0].data()[i] * 2);
}
}
}
\ No newline at end of file
......@@ -132,6 +132,8 @@ class Tensor {
std::type_index type() const { return holder_->type(); }
size_t memory_size() const;
private:
inline void check_memory_size() const;
......
......@@ -20,6 +20,8 @@
#include <algorithm>
#include <limits>
#include "paddle/framework/eigen.h"
namespace paddle {
namespace framework {
......@@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) {
values_.resize(index + 1);
}
values_[index].set_lod(value.lod());
values_[index].Resize(value.dims());
values_[index].mutable_data<value_type>(platform::CPUPlace());
values_[index].CopyFrom(value, platform::CPUPlace(),
platform::CPUDeviceContext());
values_[index].mutable_data<value_type>(value.place());
values_[index].CopyFrom(value, value.place(), platform::CPUDeviceContext());
}
void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
......@@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
values_.resize(index + 1);
}
values_[index].set_lod(value.lod());
values_[index].ShareDataWith(value);
}
......@@ -144,6 +147,155 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
return unpacker.meta;
}
LoDTensor TensorArray::LodPack(size_t level) const {
PADDLE_ENFORCE_GT(size(), 0UL, "no time step exists");
// the levels should be no less than 2
LoDTensor merged;
const LoDTensor *pre, *cur;
pre = &Read(0);
for (size_t step = 1; step < size(); step++) {
cur = &Read(step);
PADDLE_ENFORCE_GT(cur->NumLevels(), 0);
PADDLE_ENFORCE_GT(pre->NumLevels(), 0);
PADDLE_ENFORCE_EQ(pre->NumLevels(), cur->NumLevels());
PADDLE_ENFORCE_EQ(pre->NumElements(level), cur->NumElements(level));
merged = LodPackTwo(*pre, *cur, level);
pre = &merged;
}
return merged;
}
/*
* NOTE currently, only the lowest level supports packing.
* The lowest LoD will be changed, while the relative offsets in levels above
* stay unchanged.
*
* previous step : [0] [1] [3]
* current step: [0 1 2] [2 3] []
* packed to
* [0 0] [0 1] [0 2] [1 2] [1 3] [3]
*/
LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur,
size_t level) const {
PADDLE_ENFORCE_EQ(pre.NumLevels(), cur.NumLevels());
PADDLE_ENFORCE_EQ(pre.NumLevels(), level + 1,
"Only the lowest LoD level supports pack temporarily.");
// calculate the result tensor's shape first
size_t num_instances = 0;
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
if (num_candidates > 0) {
num_instances += num_candidates * (prefix_size + 1);
} else {
num_instances += prefix_size;
}
}
auto res_dims = pre.dims();
res_dims[0] = num_instances;
LoDTensor result;
result.Resize(res_dims);
result.mutable_data<value_type>(cur.place());
Vector<size_t> last_lod_level;
// copy data
size_t index = 0;
last_lod_level.push_back(index);
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
// slice the prefix Tensor
LoDTensor prefix = pre;
prefix.ShrinkInLevel(level, elem, elem + 1);
LoDTensor candidate = cur;
if (num_candidates > 0) {
candidate.ShrinkInLevel(level, elem, elem + 1);
} else { // just push prefix
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
last_lod_level.push_back(index);
}
for (size_t candi = 0; candi < num_candidates; candi++) {
// TODO(superjom) support GPU
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
// copy candidate record
result.Slice(index, index + 1)
.CopyFrom(candidate.Slice(candi, candi + 1), result.place(),
platform::CPUDeviceContext());
index++;
last_lod_level.push_back(index);
}
}
// update lod
auto lod = cur.lod();
lod.back() = last_lod_level;
result.set_lod(lod);
return result;
}
/*
* source [0 1 2] [3 4] [5 6 7] will be transformd to a list of LoDTensors such
* as
* [0 3 5] [1 4 6] [2 7] with 1-level LoDs:
* - [0 1 2 3]
* - [0 1 2 3]
* - [0 1 1 2], the [1,1) here means the second sequence is empty
*
* NOTE Unpack a LoDTensor in this approach may result in a big LoD.
*/
void TensorArray::LodUnpack(const LoDTensor& source, size_t level) {
PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1,
"only the lowest LoD level supports unpack.");
const size_t non_empty_instances = source.dims()[0];
size_t index = 0;
Vector<size_t> lowest_lod_level;
lowest_lod_level.push_back(index);
for (size_t step = 0; step < non_empty_instances; step++) {
size_t num_instances = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
num_instances++;
index++;
}
lowest_lod_level.push_back(index);
}
// create tensor for this time step
LoDTensor tensor;
auto dims = source.dims();
dims[0] = num_instances;
// set lod
auto lod = source.lod();
lod.back() = lowest_lod_level;
tensor.set_lod(lod);
index = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
// copy this instance
tensor.Slice(index, index + 1)
.CopyFrom(instance.Slice(step, step + 1), tensor.place(),
platform::CPUDeviceContext());
index++;
}
}
Write(step, tensor);
}
}
LoDTensor TensorArray::Stack() const {
LoDTensor result;
if (size() == 0) return result;
......
......@@ -86,6 +86,16 @@ class TensorArray {
*/
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
/*
* Pack an array of LoDTensors to a LoDTensor.
*/
LoDTensor LodPack(size_t level) const;
/*
* Unpack a LoDTensor to an array of LoDTensors.
*/
void LodUnpack(const LoDTensor &source, size_t level);
/*
* Pack the values into a tensor with rank one higher than each tensor in
* values.
......@@ -111,6 +121,9 @@ class TensorArray {
protected:
void Unstack(const LoDTensor &source, bool data_shared) const;
LoDTensor LodPackTwo(const LoDTensor &pre, const LoDTensor &cur,
size_t level) const;
private:
mutable std::vector<LoDTensor> values_;
}; // class TensorArray
......
......@@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) {
ASSERT_EQ(ta.size(), static_cast<size_t>(batch_size));
}
TEST(TensorArray, LodPack) {
// three time steps, each step stores a LoDTensors
// - [0] [1]
// - [2 3], [4 5]
// - [6 7] [] [8], [9, 10]
// try to get a LoDTensor with content:
// - [0 2 6]
// - [0 2 7]
// - [0 3]
// - [1 4 8]
// - [1 5 9]
// - [1 5 10]
std::array<LoDTensor, 3> tensors;
tensors[0].Resize(make_ddim({2, 1}));
tensors[1].Resize(make_ddim({4, 1}));
tensors[2].Resize(make_ddim({5, 1}));
int index = 0;
for (auto& t : tensors) {
t.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < t.dims()[0]; i++) {
t.data<int>()[i] = index;
index++;
}
}
std::array<LoD, 3> lods;
std::vector<std::vector<size_t>> levels{
{0, 1, 2}, {0, 2, 4}, {0, 2, 2, 3, 5}};
for (int i = 0; i < 3; i++) {
lods[i].emplace_back(levels[i].begin(), levels[i].end());
}
TensorArray ta;
for (int i = 0; i < 3; i++) {
tensors[i].set_lod(lods[i]);
ta.Write(i, tensors[i]);
}
auto merged = ta.LodPack(0);
std::vector<int> target_tensor_data{{0, 2, 6, // 0
0, 2, 7, // 1
0, 3, // 2
1, 4, 8, // 3
1, 5, 9, // 5
1, 5, 10}};
EXPECT_EQ(merged.dims()[0], (int)target_tensor_data.size());
for (size_t i = 0; i < target_tensor_data.size(); i++) {
EXPECT_EQ(target_tensor_data[i], merged.data<int>()[i]);
}
}
} // namespace framework
} // namespace paddle
......@@ -62,12 +62,16 @@ inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
holder_->size(), numel() * SizeOfType(type()) + offset_,
holder_->size(), memory_size() + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
}
inline size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : numel() * SizeOfType(type());
}
template <typename T>
inline const T* Tensor::data() const {
check_memory_size();
......
......@@ -46,6 +46,8 @@ class Variable {
std::type_index(typeid(T)) == std::type_index(holder_->Type());
}
void Clear() { holder_.reset(); }
private:
struct Placeholder {
virtual ~Placeholder() {}
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MKLDNNBatchNormLayer.h"
using namespace mkldnn; // NOLINT
typedef memory::format format;
namespace paddle {
REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
const real MKLDNNBatchNormLayer::EPS = 1E-5;
bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
return false;
}
// first one is input layer
// the other two are created in config_parser.py saving moving mean and var
CHECK_EQ(inputLayers_.size(), 3U);
CHECK_EQ(inputLayers_.size(), parameters_.size());
CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size()));
const ImageConfig& conf = config_.inputs(0).image_conf();
ic_ = conf.channels();
ih_ = inputLayers_[0]->getOutput().getFrameHeight();
iw_ = inputLayers_[0]->getOutput().getFrameWidth();
if (iw_ == 0 && ih_ == 0) {
iw_ = conf.img_size();
ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
}
oc_ = ic_;
oh_ = ih_;
ow_ = iw_;
if (config_.has_use_global_stats()) {
useGlobalStats_ = config_.use_global_stats();
}
movingAvgFraction_ = config_.moving_average_fraction();
VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
<< " --- global stats";
VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
initWeight();
movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0));
movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0));
return true;
}
void MKLDNNBatchNormLayer::initWeight() {
weight_.reset(new Weight(1, oc_, parameters_[0]));
if (biasParameter_.get() != NULL) {
biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
}
CHECK_EQ(weight_ != nullptr, biases_ != nullptr)
<< "only support have both weight and bias, or neither";
if (weight_ && weight_->getW()) {
CHECK(biases_ && biases_->getW());
valueScaleShift_ = Matrix::create(2, oc_, false, false);
valueScaleShift_->zeroMem();
VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0));
VectorPtr shift(
new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_));
const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE);
const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE);
scale->copyFrom(*wgt);
shift->copyFrom(*bias);
wgt->setData(valueScaleShift_->getData());
bias->setData(valueScaleShift_->getData() + oc_);
}
if (weight_ && weight_->getWGrad()) {
CHECK(biases_ && biases_->getWGrad());
gradScaleShift_ = Matrix::create(2, oc_, false, false);
gradScaleShift_->zeroMem();
const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT);
const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT);
wgt->setData(gradScaleShift_->getData());
bias->setData(gradScaleShift_->getData() + oc_);
}
}
void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
if (hasInitedWgt_) {
return;
}
// prepare mean and var if necessary
if (useGlobalStats_) {
CHECK(mean_);
CHECK(var_);
mean_->copyFrom(*(movingMean_->getW()));
var_->copyFrom(*(movingVar_->getW()));
}
hasInitedWgt_ = true;
}
void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
// calculating and saving moving mean and variance
CHECK_EQ(useGlobalStats_, false);
movingMean_->getW()->add(
*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
// here var is v^2
movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}
void MKLDNNBatchNormLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw);
oh = ih;
ow = ow;
// ic_ and oc can not be changed
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
<< "Input channel can not be changed";
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
printSizeInfo();
}
void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
// In training phase, it will always calculate mean and var,
// so useGlobalStats must be false.
// In scoring phase, it depends on useGlobalStats choice.
if (passType_ != PASS_TEST && useGlobalStats_ == true) {
LOG(WARNING) << "use_global_stats is invalid setting in training phase";
useGlobalStats_ = false;
}
resetFwdBuffers(in, wgt, out);
resetFwdPD(fwdPD_, in, wgt, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, out);
}
void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
std::shared_ptr<bn_bwd::primitive_desc> pd;
resetBwdBuffers(in, wgt, out);
resetBwdPD(pd, in, wgt, out);
resetBwdPipeline(pipeline, pd, in, wgt, out);
}
void MKLDNNBatchNormLayer::forward(PassType passType) {
MKLDNNLayer::forward(passType);
// calculate and save moving mean and variance
if (passType_ != PASS_TEST) {
calMovingMeanAndVar();
}
}
void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) {
weight_->getParameterPtr()->incUpdate(callback);
if (biases_ && biases_->getWGrad()) {
biases_->getParameterPtr()->incUpdate(callback);
}
}
void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
resetInValue(in);
memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_};
CHECK(in);
auto outPD =
MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_);
resetOutValue(out, outPD);
if (valueScaleShift_) {
auto pd = MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_);
resetWithMatrix(wgt, valueScaleShift_, pd);
}
if (passType_ != PASS_TEST || useGlobalStats_) {
auto pd = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
mean_ = MKLDNNMatrix::create(pd);
var_ = MKLDNNMatrix::create(pd);
}
}
void MKLDNNBatchNormLayer::resetFwdPD(
std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in,
MKLDNNMatrixPtr wgt,
MKLDNNMatrixPtr out) {
flags_ = 0u;
prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
: prop_kind::forward_training;
if (useGlobalStats_) {
flags_ = (flags_ | batch_normalization_flag::use_global_stats);
}
if (wgt) {
flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
}
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
// TODO(TJ): use check macro
CHECK(out);
CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc());
if (wgt) {
CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc());
}
if (passType_ != PASS_TEST || useGlobalStats_) {
CHECK(mean_);
CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
CHECK(var_);
CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
}
}
void MKLDNNBatchNormLayer::resetFwdPipeline(
std::vector<primitive>& pipeline,
std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
if (passType_ == PASS_TEST) {
if (useGlobalStats_) {
fwd_.reset(wgt != nullptr ? new bn_fwd(*pd,
*in,
(const primitive::at)(*mean_),
(const primitive::at)(*var_),
*wgt,
*out)
: new bn_fwd(*pd,
*in,
(const primitive::at)(*mean_),
(const primitive::at)(*var_),
*out));
} else {
fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out)
: new bn_fwd(*pd, *in, *out));
}
} else {
CHECK_EQ(useGlobalStats_, false)
<< "useGlobalStats should be false in training";
fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_)
: new bn_fwd(*pd, *in, *out, *mean_, *var_));
}
pipeline.push_back(*fwd_);
}
void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
CHECK(inVal_ && outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc());
resetInGrad(in, inVal_->getPrimitiveDesc());
if (gradScaleShift_) {
CHECK(wgtVal_);
resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
}
}
void MKLDNNBatchNormLayer::resetBwdPD(
std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
pd = nullptr;
if (in == nullptr) {
return;
}
CHECK(out);
CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc());
auto md = in->getMemoryDesc();
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
// TODO(TJ): use check macro
CHECK(wgt);
CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc());
CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
CHECK(mean_);
CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
CHECK(var_);
CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
}
void MKLDNNBatchNormLayer::resetBwdPipeline(
std::vector<primitive>& pipeline,
std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) {
if (pd == nullptr) {
return;
}
CHECK(inVal_);
bwdData_.reset(
wgt && wgtVal_
? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt)
: new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in));
pipeline.push_back(*bwdData_);
}
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "MKLDNNLayer.h"
#include "mkldnn.hpp"
namespace paddle {
typedef mkldnn::batch_normalization_forward bn_fwd;
typedef mkldnn::batch_normalization_backward bn_bwd;
/**
* @brief A subclass of MKLDNNLayer BatchNorm layer.
*
* The config file api is mkldnn_batch_norm
*/
class MKLDNNBatchNormLayer : public MKLDNNLayer {
protected:
// save forward primitive_desc, which can be used backward
std::shared_ptr<bn_fwd::primitive_desc> fwdPD_;
// Epsilon value used in the batch normalization formula.
static const real EPS;
// weight and bias in paddle
std::unique_ptr<Weight> weight_;
std::unique_ptr<Weight> biases_;
// mkldnn use a large buffer store both scale and shift
// which are weight and bias in paddle corresponding.
MatrixPtr valueScaleShift_;
MatrixPtr gradScaleShift_;
// Moving average of mean.
std::unique_ptr<Weight> movingMean_;
// Moving average of variance.
std::unique_ptr<Weight> movingVar_;
// if useGlobalStats_ is true, will use the loaded mean and variance.
// otherwise, calculate mean and variance in every mini-batch.
bool useGlobalStats_;
// used in MKLDNN primitive desc
unsigned flags_;
// use to compute moving mean and variance.
real movingAvgFraction_;
// whether the weight has been init
bool hasInitedWgt_;
// local mean and variance
// when useGlobalStats_ they are loaded from moving mean and variance
// when do not useGlobalStats_ they are calculated from this mini-batch
MKLDNNMatrixPtr mean_;
MKLDNNMatrixPtr var_;
public:
explicit MKLDNNBatchNormLayer(const LayerConfig& config)
: MKLDNNLayer(config), useGlobalStats_(true), hasInitedWgt_(false) {}
~MKLDNNBatchNormLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override;
void convertWeightsFromPaddle() override;
protected:
void initWeight();
/**
* cal moving mean and variance.
* moving = moving * AvgFraction + local * (1 - AvgFraction)
*/
void calMovingMeanAndVar();
/**
* Forward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in,
MKLDNNMatrixPtr wgt,
MKLDNNMatrixPtr out);
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<bn_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
void resetBwdPD(std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<bn_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out);
};
} // namespace paddle
......@@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() {
// init randome parameters of ref, and copy to mkldnn
void MKLDNNTester::randomWgtDatas() {
EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size());
const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < parameters_[REF].size(); ++i) {
const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE);
const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE);
parameters_[REF][i]->randomize();
if (isBN && i == 2) {
// this param is moving average in batch norm, which must larger than 0
real offset = fabs(refValue->getMin()) + 1.0;
refValue->add(offset);
}
dnnValue->copyFrom(*refValue);
VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName();
......@@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() {
void MKLDNNTester::checkBackwardData() {
VLOG(MKLDNN_TESTS) << "Check Backward Data";
// TODO(TJ): uncomment me when batch norm ready
// const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad();
const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad();
......@@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() {
double delta = compareMatrix(dnnDiff, refDiff);
EXPECT_LE(fabs(delta), eps_);
// TODO(TJ): uncomment me when batch norm ready
// if (isBN) {
// // the other two inputs in batch norm are for moving mean and var
// break;
// }
if (isBN) {
// the other two inputs in batch norm are for moving mean and var
// do not have grad to compare
break;
}
}
}
......@@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) {
void MKLDNNTester::runOnce() {
// test forward
randomBotDatas();
dnnLayer_->forward(PASS_TRAIN);
refLayer_->forward(PASS_TRAIN);
dnnLayer_->forward(passType_);
refLayer_->forward(passType_);
checkForward();
if (passType_ == PASS_TEST) {
return;
}
// test backward
// simple updater
UpdateCallback updateCallback = [](Parameter* para) {
......@@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
size_t batchSize,
size_t inputImgH,
size_t inputImgW,
PassType passType,
bool printDetails,
size_t iter,
float epsilon) {
......@@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
ih_ = inputImgH;
iw_ = inputImgW;
passType_ = passType;
log_ = printDetails;
iter_ = iter;
eps_ = epsilon;
......
......@@ -62,12 +62,15 @@ protected:
float eps_;
/// input image size, default 1
size_t ih_, iw_;
/// passType, PASS_TRAIN, PASS_TEST or PASS_GC (Gradient Check pass)
PassType passType_;
public:
explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) {
iter_ = iter;
eps_ = epsilon;
log_ = false;
passType_ = PASS_TRAIN;
}
~MKLDNNTester() {}
......@@ -78,6 +81,7 @@ public:
size_t batchSize,
size_t inputImgH = 1,
size_t inputImgW = 1,
PassType passType = PASS_TRAIN,
bool printDetails = false,
size_t iter = 3,
float epsilon = 1e-4);
......
......@@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) {
testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2});
}
struct testBatchNormDesc {
int bs;
int ic;
int ih, iw;
};
static void getMKLDNNBatchNormConfig(TestConfig& cfg,
const testBatchNormDesc& pm) {
cfg.layerConfig.set_size(pm.ic * pm.ih * pm.iw);
cfg.layerConfig.set_type("mkldnn_batch_norm");
cfg.biasSize = pm.ic;
cfg.inputDefs.push_back(
{INPUT_DATA,
"layer_0",
/* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw),
/* size of weight= */ size_t(pm.ic)});
cfg.inputDefs.push_back(
{INPUT_DATA, "layer_1_moving_mean", 1, size_t(pm.ic)});
cfg.inputDefs.back().isStatic = true;
cfg.inputDefs.push_back({INPUT_DATA, "layer_2_moving_var", 1, size_t(pm.ic)});
cfg.inputDefs.back().isStatic = true;
LayerInputConfig* input = cfg.layerConfig.add_inputs();
// TODO(TJ): uncomment me when refine and support comparing all zeroes vector
// cfg.layerConfig.set_active_type("relu");
cfg.layerConfig.add_inputs();
cfg.layerConfig.add_inputs();
ImageConfig* img_conf = input->mutable_image_conf();
img_conf->set_channels(pm.ic);
img_conf->set_img_size_y(pm.ih);
img_conf->set_img_size(pm.iw);
}
void testBatchNormLayer(const testBatchNormDesc& pm) {
TestConfig dnnConfig;
getMKLDNNBatchNormConfig(dnnConfig, pm);
TestConfig refConfig = dnnConfig;
refConfig.layerConfig.set_type("batch_norm");
// for PASS_TRAIN, use_global_stats always should be false, and batchsize != 1
VLOG(MKLDNN_TESTS) << "check train phase";
dnnConfig.layerConfig.set_use_global_stats(false);
refConfig.layerConfig.set_use_global_stats(false);
MKLDNNTester tester;
tester.run(dnnConfig, refConfig, pm.bs, pm.ih, pm.iw, PASS_TRAIN);
// for PASS_TEST, check use_global_stats true and false, and batchsize 1
VLOG(MKLDNN_TESTS) << "check test phase";
for (auto useGS : {false, true}) {
dnnConfig.layerConfig.set_use_global_stats(useGS);
refConfig.layerConfig.set_use_global_stats(useGS);
MKLDNNTester tester;
for (auto bs : {pm.bs, 1}) {
tester.run(dnnConfig, refConfig, bs, pm.ih, pm.iw, PASS_TEST);
}
}
}
TEST(MKLDNNLayer, BatchNormLayer) {
testBatchNormLayer({4, 10, 6, 6});
testBatchNormLayer({16, 32, 16, 16});
}
struct testActDesc {
int bs, ic, ih, iw;
};
......
......@@ -91,6 +91,11 @@ public:
const MKLDNNMatrixPtr& dst,
bool checkData = true);
void copyFrom(const Matrix& src) {
// TODO(TJ): reorder data if this format is not nchw or x
m_->copyFrom(src);
}
public:
/**
* Reorder this MKLDNNMatrix from other format.
......
......@@ -60,7 +60,7 @@ public:
*/
inline real* get(int row) const {
if (preallocatedBuf_) {
CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize());
CHECK_LE((row)*width_ * sizeof(real), preallocatedBuf_->getSize());
return reinterpret_cast<real*>(preallocatedBuf_->getBuf()) + row * width_;
} else {
CHECK_LE((row + 1) * width_, rowStore_.size());
......
......@@ -54,6 +54,5 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif
} // namespace memory
} // namespace paddle
......@@ -82,7 +82,7 @@ function(op_library TARGET)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# nccl_op contains several operators
if ("${TARGET}" STREQUAL "nccl_op")
set(pybind_flag 1)
......@@ -132,6 +132,7 @@ set(DEPS_OPS
pool_op
pool_with_index_op
nccl_op
sequence_conv_op
lstm_op)
......@@ -146,6 +147,7 @@ op_library(pool_with_index_op DEPS pooling)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......@@ -164,3 +166,4 @@ cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic
if(WITH_GPU)
nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/batch_norm_op.h"
#include <cfloat>
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
void ExtractNCWHD(const framework::DDim &dims,
const TensorFormat &tensor_format, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
*C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
*H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
: 1;
}
template <typename T>
class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
VLOG(1) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * D * C, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0);
// FIXME(qiao) should not set zero self
functor(ctx.device_context(), mean_out, 0);
functor(ctx.device_context(), variance_out, 0);
auto handle = ctx.cuda_device_context().cudnn_handle();
// Now, depending on whether we are running test or not, we have two paths.
if (is_test) {
// only when test we use input to do computation.
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
// Run inference mode.
PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(est_mean->dims()[0], C);
PADDLE_ENFORCE_EQ(est_var->dims()[0], C);
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference(
handle,
// Note: PERSISTENT not implemented for inference
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// initialize them.
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(), bias->template data<T>(), this_factor,
mean_out->template mutable_data<T>(ctx.GetPlace()),
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean->template mutable_data<T>(ctx.GetPlace()),
saved_variance->template mutable_data<T>(ctx.GetPlace())));
}
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
template <typename T>
class BatchNormGradKernel<platform::GPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
std::vector<int> dims = {N, C, H, W, D};
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
ctx.cuda_device_context().cudnn_handle(), mode_,
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
x->template data<T>(), data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(),
d_scale->template mutable_data<T>(ctx.GetPlace()),
d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean_data, saved_var_data));
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(batch_norm,
ops::BatchNormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::GPUPlace, float>);
......@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`.
namespace ops = paddle::operators;
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
ops::CrossEntropyOpKernel<double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<float>);
ops::CrossEntropyGradientOpKernel<float>,
ops::CrossEntropyGradientOpKernel<double>);
......@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>);
ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_constant_batch_size_like_op.h"
namespace paddle {
namespace operators {
class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Input"),
"Input(Input) of FillConstantBatchSizeLikeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FillConstantBatchSizeLikeOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto dims = framework::make_ddim(shape_int64);
dims[0] = ctx->GetInputDim("Input")[0];
ctx->SetOutputDim("Out", dims);
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
}
};
class FillConstantBatchSizeLikeOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddInput("Input",
"(Tensor) Tensor "
"whose first dimension is used to specify the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddComment(R"DOC(Fill up a variable with specified constant value.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp,
ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_constant_batch_size_like_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value");
auto out_eigen = framework::EigenVector<T>::Flatten(*out);
auto place = ctx.GetEigenDevice<Place>();
out_eigen.device(place) = out_eigen.constant(static_cast<T>(value));
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include <fstream>
namespace paddle {
namespace operators {
class LoadOp : public framework::OperatorBase {
public:
LoadOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name);
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
uint32_t version;
fin.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
fin.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(),
std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case framework::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case framework::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case framework::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case framework::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
PADDLE_THROW("DataType %d not supported", desc.data_type());
}
fin.read(static_cast<char *>(buf), tensor->memory_size());
}
{ // read lod
uint64_t lod_level;
fin.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
fin.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
fin.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
}
auto place = dev_ctx.GetPlace();
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
framework::LoDTensor cpu_tensor;
cpu_tensor.ShareDataWith(*tensor);
cpu_tensor.set_lod(tensor->lod());
// reset tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
tensor->CopyFrom(cpu_tensor, place, dev_ctx);
}
}
};
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The tensor need to be loaded");
AddComment(R"DOC(Load Operator
Load operator will load a tensor variable from disk file.
)DOC");
AddAttr<std::string>("file_path",
"Variable will be loaded from \"file_path\".")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lrn_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class LRNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LRNOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LRNOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MidOut"),
"MidOut(Out) of LRNOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
ctx->SetOutputDim("Out", x_dim);
ctx->SetOutputDim("MidOut", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LRNOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", R"DOC(
(Tensor) The input of LRN operator. It must be a 4D tenor with NCHW format.
)DOC");
AddOutput("Out",
"(Tensor) The output of LRN operator, which is also the 4D "
"tensor with NCHW format.");
AddOutput("MidOut", R"Doc(
(Tensor)Middle result of lrn op.It's computed in forward process
and also used in backward process.
)Doc");
AddAttr<int>("n", R"DOC(
(int, default 5)n is “adjacent” kernel maps at the same spatial position.
)DOC")
.SetDefault(5)
.GreaterThan(0);
AddAttr<T>("k", R"DOC(
(float, default 2.0)k is the bias.
)DOC")
.SetDefault(2.0)
.GreaterThan(0.0);
AddAttr<T>("alpha", R"DOC(
(float, default 0.0001)alpha is the scale number.
)DOC")
.SetDefault(0.0001)
.GreaterThan(0.0);
AddAttr<T>("beta", R"DOC(
(float, default 0.75)beta is the power number.
)DOC")
.SetDefault(0.75)
.GreaterThan(0.0);
AddComment(R"DOC(
Local Response Normalization.
This Function comes from the paper
"ImageNet Classification with Deep Convolutional Neural Networks".
The original formula is:
Input(i, x, y)
Output(i, x, y) = ----------------------------------------------
-- upper
(k + alpha * > (Input(j, x, y))^2) ^ (beta)
-- j = lower
upper is `min(C, c + n/2)`
lower if `max(0, c - n/2)`
Function implementation:
inputs and outpus is NCHW format, while input.shape.ndims() is equal 4.
And the meaning of each dimension(0-3) is respectively batch size,
feature maps, rows and columns.
Input and Output in the above formula is for each map(i) of one image, and
Input(i, x, y), Output(i, x, y) represents an element in an image.
C is the number of feature maps of one image, and n is a hyper-parameters
is configured when Function is initialized. The sum in the denominator
is the sum of the same position in the neighboring maps.
)DOC");
}
};
class LRNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("MidOut")),
"Input(MidOut@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lrn, ops::LRNOp, ops::LRNOpMaker<float>, lrn_grad, ops::LRNOpGrad);
REGISTER_OP_CPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(lrn_grad,
ops::LRNGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lrn_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(lrn_grad,
ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class LRNKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
// f(x) = x * ( k + alpha * SUM((x)^2) )^(-beta)
// x represents inputs
// f(x) represents outputs
void Compute(const framework::ExecutionContext& ctx) const override {
// input
const Tensor* x = ctx.Input<Tensor>("X");
auto x_dims = x->dims();
// NCHW
int N = x_dims[0];
int C = x_dims[1];
int H = x_dims[2];
int W = x_dims[3];
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
// MidOut save the intermediate result for backward
Tensor* mid = ctx.Output<Tensor>("MidOut");
mid->mutable_data<T>(ctx.GetPlace());
int n = ctx.Attr<int>("n");
T alpha = ctx.Attr<float>("alpha");
T beta = ctx.Attr<float>("beta");
T k = ctx.Attr<float>("k");
PADDLE_ENFORCE(n > 0, "n should >= 0");
PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0");
PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
auto x_v = framework::EigenVector<T>::Flatten(*x);
const int start = -(n - 1) / 2;
const int end = start + n;
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid.device(ctx.GetEigenDevice<Place>()) = e_mid.constant(k);
auto e_x = framework::EigenTensor<T, 4>::From(*x);
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
for (int c = start; c <= end; c++) {
int ch = i + c;
if (ch >= 0 && ch < C) {
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
s.device(ctx.GetEigenDevice<Place>()) += alpha * r.square();
}
}
}
}
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(ctx.GetEigenDevice<Place>()) =
x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
}
};
/**
* \brief Backward calculation for normalization with across maps.
*
* Function implementation:
*
* The implementation of this Function is derived from the
* CrossMapNormalFunc implementation.
*
* InputGrad = OutputGrad * denoms ^ (-beta)
* -- upper
* + > (OutputGrad * OutputValue * (-2 * alpha * beta) / MidOut) * InputValue
* -- lower
*
* The data of inputs/outputs format is the same as the forward interface
* and is NCHW.
*
* The upper and lower is the same as forward. The logic of the sum
* is also the same as forward.
*/
template <typename Place, typename T>
class LRNGradKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* out = ctx.Input<Tensor>("Out");
const Tensor* out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
const Tensor* mid = ctx.Input<Tensor>("MidOut");
auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
x_g->mutable_data<T>(ctx.GetPlace());
auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
x_g_e.device(ctx.GetEigenDevice<Place>()) = x_g_e.constant(0.0);
auto x_dims = x->dims();
int N = x_dims[0];
int C = x_dims[1];
int H = x_dims[2];
int W = x_dims[3];
int n = ctx.Attr<int>("n");
T alpha = ctx.Attr<T>("alpha");
T beta = ctx.Attr<T>("beta");
T ratio = -2 * alpha * beta;
auto e_x = framework::EigenTensor<T, 4>::From(*x);
auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
auto e_out = framework::EigenTensor<T, 4>::From(*out);
auto e_out_g = framework::EigenTensor<T, 4>::From(*out_g);
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
const int start = -(n - 1) / 2;
const int end = start + n;
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
i_x_g.device(ctx.GetEigenDevice<Place>()) = i_mid.pow(-beta) * i_out_g;
for (int c = start; c <= end; c++) {
int ch = i + c;
if (ch < 0 || ch >= C) {
continue;
}
auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
i_x_g.device(ctx.GetEigenDevice<Place>()) +=
ratio * c_out_g * c_out * i_x / c_mid;
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -9,6 +9,7 @@ if(WITH_GPU)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
else()
......@@ -18,6 +19,7 @@ else()
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/context_project.h"
namespace paddle {
namespace operators {
namespace math {
template class ContextProjectFunctor<platform::CPUPlace, float>;
template class ContextProjectFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,28 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework;
import "framework.proto";
/**
* This file contains necessary information for model, checkpoint.
* etc.
*/
message LoDInfo { repeated int64 level = 1; }
/**
* Save the LoDTensorDesc information through LoDTensorProto, its data memory
* is copyed to c buffer immediately. See model_format.md for details.
*/
message LoDTensorProto {
optional DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
repeated LoDInfo levels = 3;
optional int32 lod_level = 4 [ default = 0 ];
optional int32 version = 5;
}
#define EIGEN_USE_GPU
#include "paddle/operators/math/context_project.h"
namespace paddle {
namespace operators {
namespace math {
template class ContextProjectFunctor<platform::GPUPlace, float>;
template class ContextProjectFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/im2col.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
/*
* \brief Context projection concatenate features in adjacent time steps in
* a sequence. The i-th row of the output is the concatenation of
* context_length rows of the input. The context_length rows are the
* consecutive rows from the i+shift_start row.
* \param in Input data.
* \param Shape The shape of Input data,
* [minibatch, number_of_input_features].
* \param type A float LoDTensor.
*
* \param padding_data Padding data.
* \param Shape The shape of Padding data,
* [up_pad + down_pad, number_of_input_features].
* \param type A float Tensor.
*
* \param col Col data.
* \param Shape The shape of Col data,
* [minibatch, context_length * number_of_input_features].
* \param type A float Tensor.
*
* For a mini-batch of 2 variable lengths sentences, containing 3, and 1
* time-steps:
*
* Assumed input (X) is a [4, M, N] float LoDTensor, and X->lod()[0] = [0, 3,
* 4].
* Besides, for the sake of simplicity, we assume M=1 and N=2.
*
* X = [[a1, a2;
* b1, b2;
* c1, c2]
* [d1, d2]]
*
* This is to say that input (X) has 4 words and the dimension of each word
* representation is 2.
*
* - Case1:
* If context_start is -1 and padding_trainable is false, we use zero to pad
* instead of learned weight to pad,
* and the context_lenth is 3, the output (Out) is:
*
* Out =[[0, 0, a1, a2, b1, b2;
* a1, a2, b1, b2, c1, c2;
* b1, b2, c1, c2, 0, 0 ]
* [0, 0, d1, d2, 0, 0 ]]
*
* - Case2:
* If context_start is -1 and padding_trainable is true, we use learned weight
* to pad,
* and the context_lenth is 3, the output (Out) is:
*
* Out = [[w1, w2, a1, a2, b1, b2;
* a1, a2, b1, b2, c1, c2;
* b1, b2, c1, c2, w3, w4]
* [w1, w2, d1, d2, w3, w4]]
*
*/
template <typename Place, typename T>
class ContextProjectFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::LoDTensor& in, framework::Tensor& padding_data,
framework::Tensor& col, bool padding_trainable,
int context_start, int context_length, int context_stride,
int up_pad, int down_pad, bool gradient, bool input_grad,
bool pad_grad) {
auto lod_level_0 = in.lod()[0];
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
col2im_ocf;
int input_row_begin, input_row_end;
int sequence_height, sequence_width;
sequence_width = in.dims()[1];
input_grad = gradient && input_grad;
pad_grad = gradient && pad_grad;
if (!gradient || input_grad) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
input_row_begin = (context_start > 0)
? static_cast<int>(lod_level_0[i]) + context_start
: static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]);
framework::Tensor out_t =
col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
if (input_row_begin < input_row_end) {
framework::Tensor in_t = in.Slice(input_row_begin, input_row_end);
std::vector<int64_t> output_shape(
{sequence_height, 1, 1, context_length,
sequence_width}); // output_height, output_width,
// input_channels, filter_height, filter_width
out_t.Resize(framework::make_ddim(output_shape));
std::vector<int64_t> input_shape(
{1, input_row_end - input_row_begin,
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
if (gradient) {
col2im_ocf(context, in_t, out_t,
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
} else {
im2col_ocf(context, in_t, out_t,
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
}
out_t.Resize({sequence_height, context_length * sequence_width});
}
}
}
if (!gradient || pad_grad) {
if (padding_trainable) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
framework::Tensor out_t =
col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
// add up trainable data
out_t.Resize({sequence_height * context_length, sequence_width});
if (up_pad > 0) { // add up pad
int padding_rows = std::min(
up_pad, static_cast<int>(lod_level_0[i + 1] - lod_level_0[i]));
for (int k = 0; k < padding_rows; ++k) {
int padding_size =
k + context_length < up_pad ? context_length : up_pad - k;
framework::Tensor out_t_sub = out_t.Slice(
k * context_length, k * context_length + padding_size);
framework::Tensor w_sub = padding_data.Slice(k, k + padding_size);
// in this block, using EigenVector<T>::Flatten is ok too.
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
if (gradient) {
w_sub_e.device(*context.GetEigenDevice<Place>()) =
w_sub_e + out_t_sub_e;
} else {
out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
}
}
}
if (down_pad > 0) { // add down pad
int down_pad_begin_row =
std::max(
0, (sequence_height - context_start - context_length) + 1) +
1;
int padding_begin = std::max(0, context_start - sequence_height);
int padding_size =
sequence_height - context_start >= context_length
? 1
: context_length - (sequence_height - context_start);
if (context_start >= sequence_height) padding_size = context_length;
int padding_idx = padding_begin;
for (int t = 0; t + down_pad_begin_row <= sequence_height;
++t, ++padding_size) {
if (context_start >= sequence_height)
padding_size = context_length;
if (padding_size > context_length) {
padding_size = context_length;
padding_idx++;
}
if (padding_begin > 0 || sequence_height == context_start)
padding_idx = padding_begin + t;
framework::Tensor out_t_sub = out_t.Slice(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
framework::Tensor w_sub = padding_data.Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
if (gradient) {
w_sub_e.device(*context.GetEigenDevice<Place>()) =
w_sub_e + out_t_sub_e;
} else {
out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
}
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
}
}
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
};
template class CrossEntropyFunctor<platform::CPUPlace, float>;
template class CrossEntropyFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -39,11 +39,36 @@ __device__ __forceinline__ T sum_single_warp(T val) {
return val;
}
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ T* GetPointer() { return NULL; }
};
template <>
struct SharedMemory<float> {
__device__ float* GetPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ double* GetPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
extern __shared__ T d_sum[];
SharedMemory<T> d_sum_shared;
T* d_sum = d_sum_shared.GetPointer();
d_sum[tid] = 0;
int cur_idx = tid;
......@@ -102,6 +127,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
};
template class CrossEntropyFunctor<platform::GPUPlace, float>;
template class CrossEntropyFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/proximal_adagrad_op.h"
namespace paddle {
namespace operators {
class ProximalAdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("MomentOut"),
"Output(MomentOut) of ProximalAdagradOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad of ProximalAdagrad Op must have same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Moment"),
"Param and Moment of ProximalAdagrad Op must have same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
}
};
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalAdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated.");
AddInput("Moment",
"(Tensor, default Tensor<float>) "
"Moment parameter that has to be updated.");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter.");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1.");
AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
AddOutput("MomentOut", "(Tensor) Output updated moment value.");
AddAttr<float>("l1",
"(float, default 0.0) "
"L1 regularization strength.")
.SetDefault(0.0f);
AddAttr<float>("l2",
"(float, default 0.0)"
"L2 regularization strength.")
.SetDefault(0.0f);
AddComment(R"DOC(
Optimizer that implements the proximal adagrad algorithm.
moment = moment + grad * grad
prox_param = param - learning_rate * grad * (1 / sqrt(moment))
param = sign(prox_param) / (1 + learning_rate * l2) *
max { |prox_param| - learning_rate * l1 , 0 }
The paper that proposed Proximal GD:
(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf)
Here, we use the adagrad learning rate as specified here:
(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(proximal_adagrad, ops::ProximalAdagradOp,
ops::ProximalAdagradOpMaker);
REGISTER_OP_CPU_KERNEL(
proximal_adagrad,
ops::ProximalAdagradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/proximal_adagrad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
proximal_adagrad,
ops::ProximalAdagradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class ProximalAdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
auto l1 = static_cast<T>(ctx.Attr<float>("l1"));
auto l2 = static_cast<T>(ctx.Attr<float>("l2"));
auto grad = ctx.Input<Tensor>("Grad");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto g = EigenVector<T>::Flatten(*grad);
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto m_out = EigenVector<T>::Flatten(*moment_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
m_out.device(place) = m + g * g;
auto prox_param = p - lr.broadcast(grad_dsize) * g / m_out.sqrt();
if (l1 > static_cast<T>(0)) {
p_out.device(place) =
prox_param.sign() *
(((prox_param.abs() - (lr * l1).broadcast(grad_dsize))
.cwiseMax(static_cast<T>(0.0))) /
(static_cast<T>(1.0) + (lr * l2).broadcast(grad_dsize)));
} else {
p_out.device(place) =
prox_param / (static_cast<T>(1.0) + (lr * l2).broadcast(grad_dsize));
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/op_registry.h"
USE_NO_KERNEL_OP(save);
USE_NO_KERNEL_OP(load);
TEST(SaveLoadOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
paddle::platform::CPUDeviceContext ctx(place);
auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({10, 10});
paddle::framework::LoD expect_lod;
expect_lod.resize(1);
expect_lod[0].push_back(0);
expect_lod[0].push_back(1);
expect_lod[0].push_back(2);
expect_lod[0].push_back(3);
tensor->set_lod(expect_lod);
int* expect = tensor->mutable_data<int>(place);
for (size_t i = 0; i < paddle::framework::product(tensor->dims()); ++i) {
expect[i] = static_cast<int>(i);
}
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", std::string("tensor.save")});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, ctx);
auto load_var = scope.Var("out_var");
auto target = load_var->GetMutable<paddle::framework::LoDTensor>();
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, ctx);
int* actual = target->data<int>();
for (size_t i = 0; i < paddle::framework::product(tensor->dims()); ++i) {
EXPECT_EQ(expect[i], actual[i]);
}
auto& actual_lod = target->lod();
EXPECT_EQ(expect_lod.size(), actual_lod.size());
for (size_t i = 0; i < expect_lod.size(); ++i) {
for (size_t j = 0; j < expect_lod[i].size(); ++j) {
EXPECT_EQ(expect_lod[i][j], actual_lod[i][j]);
}
}
}
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <fstream>
#include <numeric>
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
// TODO(yuyang18): If the functions below are needed by other files, move them
// to paddle::filesystem namespace.
constexpr char kSEP = '/';
static bool FileExists(const std::string &filepath) {
struct stat buffer;
return (stat(filepath.c_str(), &buffer) == 0);
}
static std::string DirName(const std::string &filepath) {
auto pos = filepath.rfind(kSEP);
if (pos == std::string::npos) {
return "";
}
return filepath.substr(0, pos);
}
static void MkDir(const char *path) {
if (mkdir(path, 0755)) {
PADDLE_ENFORCE_EQ(errno, EEXIST, "%s mkdir failed!", path);
}
}
static void MkDirRecursively(const char *fullpath) {
if (*fullpath == '\0') return; // empty string
if (FileExists(fullpath)) return;
MkDirRecursively(DirName(fullpath).c_str());
MkDir(fullpath);
}
class SaveOp : public framework::OperatorBase {
public:
SaveOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto iname = Input("X");
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
iname);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveOp only support LoDTensor, %s has wrong type", iname);
auto &tensor = var->Get<framework::LoDTensor>();
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
fout.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
fout.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto *data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto &gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::GPUPlace>(tensor.place()),
reinterpret_cast<const void *>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
fout.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
fout.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
}
{ // the 4th field, lod information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto lod = tensor.lod();
uint64_t size = lod.size();
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (auto &each : lod) {
size = each.size() * sizeof(framework::LoD::value_type::value_type);
fout.write(reinterpret_cast<const char *>(&size), sizeof(size));
fout.write(reinterpret_cast<const char *>(each.data()),
static_cast<std::streamsize>(size));
}
}
}
};
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The tensor need to be saved");
AddComment(R"DOC(Save operator
Save operator will serialize and write a tensor variable to disk file.
)DOC");
AddAttr<bool>("overwrite", "Overwrite the output file if exist")
.SetDefault(true);
AddAttr<std::string>("file_path",
"Variable will be saved to \"file_path\".")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include <fstream>
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
inline static std::string VarToFileName(const std::string& folder_path,
const std::string& var_name) {
return folder_path + "/__" + var_name + "__";
}
class SaveOp : public framework::OperatorBase {
public:
SaveOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
const auto& var_names = this->Inputs("X");
for (const auto& name : var_names) {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
}
std::string folder_path = this->Attr<std::string>("folderPath");
PADDLE_ENFORCE(!folder_path.empty(),
"'folderPath' of SaveOp shouldn't be empty.");
VLOG(1) << "Save variables to folder: " << folder_path;
for (const auto& name : var_names) {
std::string file_name = VarToFileName(folder_path, name);
std::ofstream fout(file_name, std::ofstream::out);
PADDLE_ENFORCE(fout.is_open(), "Fail to create file %s.", file_name);
const LoDTensor& tensor = scope.FindVar(name)->Get<LoDTensor>();
std::string bytes = tensor.SerializeToString();
fout << bytes;
fout.close();
}
VLOG(1) << "Compelete saving variables. Items count: " << var_names.size();
}
};
class SaveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(tensor), the tensor count can be 1~INT_MAX, tensors names which "
"values will be saved.")
.AsDuplicable();
AddAttr<std::string>("folderPath", "the folderPath for save model.");
AddComment(R"DOC(
Save the input tensors to a binary file based on input tensor names and absolute path.
All the inputs can carry the LoD (Level of Details) information,
or not.
)DOC");
}
};
class RestoreOp : public framework::OperatorBase {
public:
RestoreOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
const auto& var_names = this->Outputs("Out");
for (const auto& name : var_names) {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
}
std::string folder_path = this->Attr<std::string>("folderPath");
PADDLE_ENFORCE(!folder_path.empty(),
"'folderPath' of RestoreOp shouldn't be empty.");
VLOG(1) << "Try loading variables from folder: " << folder_path;
for (const auto& name : var_names) {
std::string file_name = VarToFileName(folder_path, name);
std::ifstream fin(file_name, std::ifstream::in);
PADDLE_ENFORCE(fin.is_open(), "Fail to open file %s.", file_name);
const size_t kBufferSize = 4096; // equal to linux page size
char buffer[kBufferSize];
std::string cache;
while (!fin.eof()) {
fin.read(buffer, kBufferSize);
cache.append(buffer, fin.gcount());
}
LoDTensor* tensor = scope.FindVar(name)->GetMutable<LoDTensor>();
tensor->DeserializeFromString(cache, dev_ctx.GetPlace());
fin.close();
}
VLOG(1) << "Complete loading variables.";
}
};
class RestoreOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RestoreOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out",
"(tensor), the tensor count can be 1~INT_MAX, tensors which "
"values will be restores.")
.AsDuplicable();
AddAttr<std::string>("folderPath", "the folderPath for model file.");
AddAttr<int>("data_type", "output tensor data type")
.SetDefault(framework::DataType::FP32);
AddComment(R"DOC(
Restore the tensors from model file based on absolute path.
All the tensors outputs may carry the LoD (Level of Details) information,
or not.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(save, paddle::operators::SaveOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::SaveOpMaker);
REGISTER_OPERATOR(restore, paddle::operators::RestoreOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::RestoreOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_conv_op.h"
namespace paddle {
namespace operators {
class SequenceConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceConvOp should not be null.");
int context_length = ctx->Attrs().Get<int>("context_length");
bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
int context_start = ctx->Attrs().Get<int>("context_start");
auto in_dims = ctx->GetInputDim("X");
auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(filter_dims[0] == context_length * in_dims[1],
"Filter's height should be context_length * "
"number_of_input_features .");
if (padding_trainable) {
PADDLE_ENFORCE(
ctx->HasInput("PaddingData"),
"Input(PaddingData) of SequenceConvOp should not be null.");
framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int total_pad = up_pad + down_pad;
int input_width = static_cast<int>(in_dims[1]);
if (context_start == 0 && context_length == 1) {
PADDLE_THROW(
"If context_start is 0 and context_length is 1, padding_trainable "
"should be false.");
}
PADDLE_ENFORCE(padding_dim.size() == 2,
"Input(PaddingData) should be 2-D tensor.");
PADDLE_ENFORCE(
padding_dim[0] == total_pad && padding_dim[1] == input_width,
"Input(PaddingData)'s shape is not consistent with 'context_start' "
"and 'context_length'.");
}
in_dims[1] = filter_dims[1];
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", "Out");
}
};
class SequenceConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null.");
if (ctx->Attrs().Get<bool>("padding_trainable") &&
ctx->HasOutput(framework::GradVarName("PaddingData"))) {
ctx->SetOutputDim(framework::GradVarName("PaddingData"),
ctx->GetInputDim("PaddingData"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"),
ctx->GetInputDim("Filter"));
}
}
};
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(LoDTensor) the input(X) is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the input feature size.");
AddInput("PaddingData",
"(Tensor, optional) the input(PaddingData) is an optional "
"parameter, and it is learnable. "
"This is a tensor with shape (N, D), where N is the "
"top_pad + bottom_pad, D is the input feature size. In order to "
"ensure the equal length of sequence before and after "
"convolution, it is necessary to fill the top and bottom of each "
"sequence according to context_length, context_stride and "
"context_start")
.AsDispensable();
AddInput("Filter",
"(Tensor) the input(Filter) is an learnable parameter."
"This is a tensor with shape (N, D), where N is the "
"context_length, D is the output feature size.");
AddOutput(
"Out",
"(LoDTensor) the output(Out) is a LodTensor, which support "
"variable-time length output sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the output feature size.");
AddAttr<bool>("padding_trainable",
"(bool, default false) the padding data of SequenceConvOp "
"is trainable or not.")
.SetDefault(false);
AddAttr<int>("context_length",
"(int, default 3) the context_length of SequenceConvOp is the "
"height of the convolution kernel.")
.SetDefault(3)
.GreaterThan(0);
AddAttr<int>("context_start",
"(int, default 0) the context_start of SequenceConvOp "
"represents the beginning of the convolution of the number of "
"rows of sequence, which can be negative.")
.SetDefault(0);
AddAttr<int>("context_stride",
"(int, default 1) the context_stride of SequenceConvOp "
"represents the step length of convolution. "
"Currently, SequenceConvOp only supports"
"context_stride=1.")
.SetDefault(1)
.GreaterThan(0);
AddComment(R"DOC(
SequenceConvOp performs convolution operation on features of
context_length time-steps of each instance.
The convolution operation calculates the output based on the input, filter
and strides, paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape. In order to ensure the equal
length of sequence before and after convolution, it is necessary to fill
the top and bottom of each sequence according to context_length,
context_stride and context_start.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
sequence_conv_grad, ops::SequenceConvGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_conv_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/context_project.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class SequenceConvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
auto filter = *context.Input<Tensor>("Filter");
out->mutable_data<T>(context.GetPlace());
context.ShareLoD("X", "Out");
int context_start = context.Attr<int>("context_start");
int context_length = context.Attr<int>("context_length");
int context_stride = context.Attr<int>("context_stride");
bool padding_trainable = context.Attr<bool>("padding_trainable");
// InferShape by in_lod
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
const Tensor* padding_data = nullptr;
if (padding_trainable) {
padding_data = context.Input<Tensor>("PaddingData");
}
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width;
sequence_width = static_cast<int>(in->dims()[1]);
// Use col_shape in the im2col calculation.
framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
math::SetConstant<Place, T> set_zero;
// Because if padding_trainable is false, padding data should be zeros.
set_zero(context.device_context(), &col, static_cast<T>(0));
paddle::operators::math::ContextProjectFunctor<Place, T>
seq_project_functor;
LoDTensor* input = const_cast<LoDTensor*>(in);
Tensor* pad_data = const_cast<Tensor*>(padding_data);
seq_project_functor(context.device_context(), *input, *pad_data, col,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, false, false, false);
math::matmul<Place, T>(context.device_context(), col, false, filter, false,
static_cast<T>(1.0), out, static_cast<T>(0.0));
}
};
template <typename Place, typename T>
class SequenceConvGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* filter_g = context.Output<Tensor>(framework::GradVarName("Filter"));
auto* padding_data_g =
context.Output<Tensor>(framework::GradVarName("PaddingData"));
auto* in = context.Input<LoDTensor>("X");
auto* filter = context.Input<Tensor>("Filter");
int context_start = context.Attr<int>("context_start");
int context_length = context.Attr<int>("context_length");
int context_stride = context.Attr<int>("context_stride");
bool padding_trainable = context.Attr<bool>("padding_trainable");
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
auto lod_g_level_0 = in->lod()[0];
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]);
math::SetConstant<Place, T> set_zero;
// use col_shape in the im2col calculation
framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length};
Tensor col;
if (in_g || filter_g || (padding_trainable && padding_data_g)) {
col.mutable_data<T>(col_shape, context.GetPlace());
// Because if padding_trainable is false, padding data should be zeros.
set_zero(context.device_context(), &col, static_cast<T>(0));
math::matmul<Place, T>(context.device_context(), *out_g, false, *filter,
true, T(1.0), &col, T(1.0));
}
paddle::operators::math::ContextProjectFunctor<Place, T>
seq_project_functor;
if (in_g) {
in_g->mutable_data<T>(context.GetPlace());
in_g->set_lod(in->lod());
set_zero(context.device_context(), in_g, static_cast<T>(0));
seq_project_functor(context.device_context(), *in_g, *padding_data_g, col,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, true, true, false);
}
if (padding_trainable && padding_data_g) {
padding_data_g->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), padding_data_g, static_cast<T>(0));
LoDTensor* input = const_cast<LoDTensor*>(in);
seq_project_functor(context.device_context(), *input, *padding_data_g,
col, padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, true, false, true);
}
if (filter_g) {
filter_g->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), filter_g, static_cast<T>(0));
Tensor filter_grad = *filter_g;
LoDTensor out_grad = *out_g;
const Tensor* padding_data = nullptr;
if (padding_trainable) {
padding_data = context.Input<Tensor>("PaddingData");
}
sequence_width = static_cast<int>(in->dims()[1]);
LoDTensor* input = const_cast<LoDTensor*>(in);
Tensor* pad_data = const_cast<Tensor*>(padding_data);
seq_project_functor(context.device_context(), *input, *pad_data, col,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, false, false,
false);
math::matmul<Place, T>(context.device_context(), col, true, out_grad,
false, T(1.0), &filter_grad, T(1.0));
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/squared_l2_norm_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class SquaredL2NormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
ctx->SetOutputDim("Out", {1});
}
};
class SquaredL2NormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2NormOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of squared_l2_norm op.");
AddOutput("Out", "(Float) The output of squared_l2_norm op.");
AddComment(R"DOC(
SquaredL2Norm Operator.
Computes the squared L2 norm of a tensor.
Out = sum (X ** 2)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(squared_l2_norm, ops::SquaredL2NormOp, ops::SquaredL2NormOpMaker,
squared_l2_norm_grad, ops::SquaredL2NormGradOp);
REGISTER_OP_CPU_KERNEL(
squared_l2_norm,
ops::SquaredL2NormKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
squared_l2_norm_grad,
ops::SquaredL2NormGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/squared_l2_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
squared_l2_norm,
ops::SquaredL2NormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
squared_l2_norm_grad,
ops::SquaredL2NormGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
// Out = sum(square(X))
template <typename Place, typename T>
class SquaredL2NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X");
framework::Tensor *Out = context.Output<framework::Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.square().sum();
}
};
// dX = X
template <typename Place, typename T>
class SquaredL2NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X");
const framework::Tensor *dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(dOut->numel() == 1,
"Squared L2 Norm Gradient should be scalar");
framework::Tensor *dX =
context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto dout = framework::EigenVector<T>::Flatten(*dOut);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> x_dsize(X->numel());
dx.device(place) = (dout.broadcast(x_dsize) * x) * static_cast<T>(2.0);
}
};
} // namespace operators
} // namespace paddle
......@@ -44,7 +44,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) {
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
if (momentum_ != 0.0) ProtoToTensor(state.momentums(), momentums_);
}
} // namespace optimizer
......
......@@ -15,7 +15,8 @@ template <class T>
class TensorT {
public:
TensorT(size_t size) : height_(1), width_(size) {
data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
// new T[size]() initializes all element to zero value.
data_ptr_ = std::shared_ptr<T>(new T[size](), std::default_delete<T[]>());
data_ = data_ptr_.get();
}
......
......@@ -22,6 +22,47 @@ limitations under the License. */
namespace paddle {
namespace platform {
inline const char* cudnnGetErrorString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
default:
return "Unknown cudnn error number";
}
}
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#define CUDNN_ENFORCE(condition) \
do { \
cudnnStatus_t status = condition; \
if (status != CUDNN_STATUS_SUCCESS) { \
VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \
PADDLE_THROW("cuDNN call failed"); \
} \
} while (false)
enum class DataLayout {
kNHWC,
kNCHW,
......@@ -40,12 +81,30 @@ template <>
class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
template <>
class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
typedef const double ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
......
......@@ -83,6 +83,7 @@ extern void* cudnn_dso_handle;
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
......
......@@ -186,6 +186,7 @@ void ParameterClient2::sendParallel(int tid,
parameter->getMat(recvParameterType).get());
CHECK(recvMat);
size_t width = parameter->getConfig().dims(1);
// TODO(wuyi): need add lock here? may also cause resize.
buf = recvMat->getLocalRow(block.begin_pos() / width);
}
/// sparse_id is not useful while receiving data since sparse data
......@@ -265,9 +266,9 @@ void ParameterClient2::prepareSendData(
uint64_t beginDim = 0;
uint64_t endDim = 0;
// FIXME(typhoonzero): let it resize first
prefetchMat->getLocalRow(nLocalBlocks + 1);
sendMat->getLocalRow(nLocalBlocks + 1);
// HACK(typhoonzero): let it resize first
prefetchMat->getLocalRow(nLocalBlocks);
sendMat->getLocalRow(nLocalBlocks);
for (size_t row = 0; row < nLocalBlocks; ++row) {
int64_t blockId = localIndices[row]; // local row -> sparse row
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/utils/PythonUtil.h"
DEFINE_string(model_dir, "", "Directory for separated model files");
DEFINE_string(config_file, "", "Config file for the model");
DEFINE_string(model_file, "", "File for merged model file");
using namespace paddle; // NOLINT
......@@ -28,7 +29,8 @@ using namespace std; // NOLINT
int main(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);
string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir);
string confFile = FLAGS_config_file;
#ifndef PADDLE_WITH_CUDA
FLAGS_use_gpu = false;
#endif
......
......@@ -43,11 +43,6 @@ void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) {
ParameterUpdater::init(parameters);
for (auto &para : parameters_) {
para->getBuf(PARAMETER_VALUE)->zeroMem();
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
// create parameter server client.
if (useEtcd_) {
parameterClient_ =
......@@ -109,6 +104,8 @@ void NewRemoteParameterUpdater::init(
LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
<< trainerConfig_.learning_rate_schedule() << ", set to const";
optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
trainerConfig_.learning_rate());
}
// overwrite optimizerConfigV2 for per-parameter(layer) configs
......
......@@ -89,6 +89,36 @@ tmp = img_pool_layer(input=tmp,
padding=1,
pool_type=MaxPooling())
tmp = img_conv_layer(input=tmp,
filter_size=3,
num_filters=32,
padding=1,
shared_biases=True,
act=LinearActivation(),
bias_attr=False)
tmp = batch_norm_layer(input=tmp,
use_global_stats=False,
act=ReluActivation())
c1 = img_conv_layer(input=tmp,
filter_size=1,
num_filters=32,
padding=0,
shared_biases=True,
act=ReluActivation())
c2 = img_conv_layer(input=tmp,
filter_size=3,
num_filters=32,
padding=1,
shared_biases=True,
act=ReluActivation())
tmp = addto_layer(input=[c1, c2],
act=ReluActivation(),
bias_attr=False)
tmp = fc_layer(input=tmp, size=64,
bias_attr=False,
act=TanhActivation())
......
......@@ -38,9 +38,14 @@ tmp = img_pool_layer(input=tmp,
tmp = img_conv_layer(input=tmp,
filter_size=3,
num_filters=64,
num_filters=32,
padding=1,
shared_biases=True,
act=LinearActivation(),
bias_attr=False)
tmp = batch_norm_layer(input=tmp,
use_global_stats=False,
act=ReluActivation())
tmp = img_pool_layer(input=tmp,
......
......@@ -19,7 +19,7 @@ import "ModelConfig.proto";
package paddle;
message OptimizationConfig {
required int32 batch_size = 3;
optional int32 batch_size = 3 [ default = 1 ];
required string algorithm = 4 [ default = "async_sgd" ];
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];
......
......@@ -2420,6 +2420,7 @@ class BatchNormLayer(LayerBase):
# If not use is_static, even set learning_rate = 0, decay_rate = 0,
# these paras will change if set average_window in configure.
use_gpu = bool(int(g_command_config_args.get("use_gpu", 0)))
use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
is_shared = True if not use_gpu else False
for i in xrange(2):
inputs.append(
......@@ -2433,11 +2434,17 @@ class BatchNormLayer(LayerBase):
parallel_nn = bool(int(g_command_config_args.get("parallel_nn", 0)))
cudnn_version = int(g_command_config_args.get("cudnn_version", 0))
# Automatically select cudnn_batch_norm for GPU and batch_norm for CPU.
# Also based on cudnn version.
# Automatically select cudnn_batch_norm for GPU, batch_norm for CPU
# and mkldnn_batch_norm for MKLDNN. Also based on cudnn version.
if batch_norm_type == "mkldnn_batch_norm":
config_assert(use_mkldnn, "mkldnn_batch_norm only support MKLDNN")
use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \
not use_mkldnn and batch_norm_type != "mkldnn_batch_norm" and \
((not parallel_nn) or self.config.device > -1)
self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm"
if use_cudnn:
self.layer_type = "cudnn_batch_norm"
else:
self.layer_type = "mkldnn_batch_norm" if use_mkldnn else "batch_norm"
super(BatchNormLayer, self).__init__(
name, self.layer_type, 0, inputs=inputs, **xargs)
......
......@@ -3014,16 +3014,19 @@ def batch_norm_layer(input,
:param input: batch normalization input. Better be linear activation.
Because there is an activation inside batch_normalization.
:type input: LayerOutput
:param batch_norm_type: We have batch_norm and cudnn_batch_norm. batch_norm
supports both CPU and GPU. cudnn_batch_norm requires
cuDNN version greater or equal to v4 (>=v4). But
cudnn_batch_norm is faster and needs less memory
than batch_norm. By default (None), we will
automaticly select cudnn_batch_norm for GPU and
batch_norm for CPU. Otherwise, select batch norm
type based on the specified type. If you use cudnn_batch_norm,
:param batch_norm_type: We have batch_norm, mkldnn_batch_norm and cudnn_batch_norm.
batch_norm supports CPU, MKLDNN and GPU. cudnn_batch_norm
requires cuDNN version greater or equal to v4 (>=v4).
But cudnn_batch_norm is faster and needs less
memory than batch_norm. mkldnn_batch_norm requires
enable use_mkldnn. By default (None), we will
automaticly select cudnn_batch_norm for GPU,
mkldnn_batch_norm for MKLDNN and batch_norm for CPU.
Otherwise, select batch norm type based on the
specified type. If you use cudnn_batch_norm,
we suggested you use latest version, such as v5.1.
:type batch_norm_type: None | string, None or "batch_norm" or "cudnn_batch_norm"
or "mkldnn_batch_norm"
:param act: Activation Type. Better be relu. Because batch
normalization will normalize input near zero.
:type act: BaseActivation
......@@ -3063,6 +3066,7 @@ def batch_norm_layer(input,
else:
num_channels = input.size
assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \
(batch_norm_type == "mkldnn_batch_norm") or \
(batch_norm_type == "cudnn_batch_norm")
l = Layer(
name=name,
......
......@@ -65,7 +65,14 @@ def download(url, module_name, md5sum):
os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum):
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError("Cannot download {0} within retry limit {2}".
format(url, retry_limit))
print "Cache file %s not found, downloading %s" % (filename, url)
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
......
......@@ -261,7 +261,7 @@ class Operator(object):
self.desc.set_attr(attr_name, attrs[attr_name])
self.desc.check_attrs()
no_kernel_op_set = {'feed', 'fetch', 'save', 'restore'}
no_kernel_op_set = {'feed', 'fetch', 'save', 'load'}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc)
......
......@@ -4,7 +4,8 @@ import paddle.v2.framework.framework as framework
from paddle.v2.framework.backward import append_backward_ops
__all__ = [
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer'
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer'
]
......@@ -211,13 +212,14 @@ class MomentumOptimizer(Optimizer):
"""
_velocity_acc_str = "velocity"
def __init__(self, learning_rate, momentum):
def __init__(self, learning_rate, momentum, use_nesterov=False):
assert learning_rate is not None
assert momentum is not None
super(MomentumOptimizer, self).__init__()
self.type = "momentum"
self._learning_rate = learning_rate
self._momentum = momentum
self._use_nesterov = bool(use_nesterov)
def _initialize_tensors(self, block):
assert isinstance(block, framework.Block)
......@@ -259,7 +261,8 @@ class MomentumOptimizer(Optimizer):
"ParamOut": param_and_grad[0],
"VelocityOut": velocity_acc
},
attrs={"mu": self._momentum})
attrs={"mu": self._momentum,
"useNesterov": self._use_nesterov})
return momentum_op
......@@ -397,7 +400,7 @@ class AdamOptimizer(Optimizer):
param_and_grad[0])
moment2 = self._get_accumulator(self._moment2_acc_str,
param_and_grad[0])
# create the momentum optimize op
# create the adam optimize op
adam_op = block.append_op(
type=self.type,
inputs={
......@@ -440,3 +443,108 @@ class AdamOptimizer(Optimizer):
attrs={"scale": self._beta2})
return [scale_beta1, scale_beta2]
class AdamaxOptimizer(Optimizer):
"""Implements the Adamax Optimizer
"""
_moment_acc_str = "moment"
_inf_norm_acc_str = "inf_norm"
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8):
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
super(AdamaxOptimizer, self).__init__()
self.type = "adamax"
self._learning_rate = learning_rate
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
def _initialize_tensors(self, block):
assert isinstance(block, framework.Block)
lr_shape = [1]
# create a variable for learning_rate
self._lr = block.create_var(
dtype="float32", shape=lr_shape, lod_level=0)
# create an op to init the learning_rate
# FIXME: Fix when Initialization design has been implemented
# https://github.com/PaddlePaddle/Paddle/pull/4852
block.append_op(
type="fill_constant",
outputs={"Out": self._lr},
attrs={"shape": lr_shape,
"value": self._learning_rate})
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
global_block = block.program.global_block()
# Create beta1 power accumulator tensor
beta_shape = [1]
self._beta1_pow_acc = global_block.create_var(
dtype="float32", shape=beta_shape, lod_level=0)
# Initialize beta1 power accumulator
# FIXME: Fix when Initialization design has been implemented
# https://github.com/PaddlePaddle/Paddle/pull/4852
global_block.append_op(
type="fill_constant",
outputs={"Out": self._beta1_pow_acc},
attrs={"shape": beta_shape,
"value": self._beta1})
# Create accumulator tensors for first moment and infinity norm
for p in parameters:
self._add_accumulator(block, self._moment_acc_str, p, 'float32')
self._add_accumulator(block, self._inf_norm_acc_str, p, 'float32')
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0])
inf_norm = self._get_accumulator(self._inf_norm_acc_str,
param_and_grad[0])
# create the adamax optimize op
adamax_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": self._lr,
"Moment": moment,
"InfNorm": inf_norm,
"Beta1Pow": self._beta1_pow_acc
},
outputs={
"ParamOut": param_and_grad[0],
"MomentOut": moment,
"InfNormOut": inf_norm
},
attrs={
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon
})
return adamax_op
def _finish_update(self, block):
"""Update Beta1 Power accumulator
"""
assert isinstance(block, framework.Block)
global_block = block.program.global_block()
scale_beta1 = global_block.append_op(
type="scale",
inputs={"X": self._beta1_pow_acc},
outputs={"Out": self._beta1_pow_acc},
attrs={"scale": self._beta1})
return [scale_beta1]
......@@ -8,6 +8,15 @@ from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder
def randomize_probability(batch_size, class_num, dtype='float32'):
prob = np.random.uniform(
0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
prob_sum = prob.sum(axis=1)
for i in xrange(len(prob)):
prob[i] /= prob_sum[i]
return prob
def grad_var_name(var_name):
return var_name + "@GRAD"
......@@ -233,7 +242,7 @@ def append_input_output(block, op_proto, np_list, is_input):
if (var_name not in np_list) and var_proto.dispensable:
continue
assert (var_name in np_list) or (var_proto.dispensable), \
"Missing {} as input".format(var_name)
"Missing {} as input".format(var_name)
if var_proto.duplicable:
assert isinstance(np_list[var_name], list), \
"Duplicable {} should be set as list".format(var_name)
......@@ -379,9 +388,9 @@ class OpTest(unittest.TestCase):
def err_msg():
offset = np.argmax(diff_mat > max_relative_error)
return ("%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d") % (
"the first error element is %d, %f, %f") % (
msg_prefix, name, max_diff, max_relative_error,
offset)
offset, a.flatten()[offset], b.flatten()[offset])
self.assertLessEqual(max_diff, max_relative_error, err_msg())
......@@ -389,6 +398,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
......@@ -411,6 +421,7 @@ class OpTest(unittest.TestCase):
self.inputs,
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
]
grad_names = [
......
import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, randomize_probability
class TestCrossEntropyOp1(OpTest):
......@@ -12,12 +12,12 @@ class TestCrossEntropyOp1(OpTest):
batch_size = 30
class_num = 10
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
X = randomize_probability(batch_size, class_num, dtype='float64')
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
cross_entropy = np.asmatrix(
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float32")
dtype="float64")
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
......@@ -27,7 +27,7 @@ class TestCrossEntropyOp1(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y")
self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
class TestCrossEntropyOp2(OpTest):
......@@ -39,8 +39,7 @@ class TestCrossEntropyOp2(OpTest):
batch_size = 5
class_num = 37
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
X = randomize_probability(batch_size, class_num)
label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True)
......@@ -55,7 +54,8 @@ class TestCrossEntropyOp2(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05)
self.check_grad(
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
class TestCrossEntropyOp3(OpTest):
......@@ -67,8 +67,7 @@ class TestCrossEntropyOp3(OpTest):
batch_size = 5
class_num = 17
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
X = randomize_probability(batch_size, class_num)
label_index = np.random.randint(
0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape)
......@@ -88,7 +87,8 @@ class TestCrossEntropyOp3(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05)
self.check_grad(
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
if __name__ == "__main__":
......
import unittest
import numpy as np
from op_test import OpTest
class TestFillConstantBatchSizeLikeOp(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("float32")}
self.attrs = {'value': 3.5, 'shape': [-1, 132, 777]}
out = np.random.random((219, 132, 777)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestLRNOp(OpTest):
def get_input(self):
''' TODO(gongweibao): why it's grad diff is so large?
x = np.ndarray(
shape=(self.N, self.C, self.H, self.W), dtype=float, order='C')
for m in range(0, self.N):
for i in range(0, self.C):
for h in range(0, self.H):
for w in range(0, self.W):
x[m][i][h][w] = m * self.C * self.H * self.W + \
i * self.H * self.W + \
h * self.W + w + 1
'''
x = np.random.rand(self.N, self.C, self.H, self.W).astype("float32")
return x + 1
def get_out(self):
start = -(self.n - 1) / 2
end = start + self.n
mid = np.empty((self.N, self.C, self.H, self.W), dtype=float)
mid.fill(self.k)
for m in range(0, self.N):
for i in range(0, self.C):
for c in range(start, end + 1):
ch = i + c
if ch < 0 or ch >= self.C:
continue
s = mid[m][i][:][:]
r = self.x[m][ch][:][:]
s += np.square(r) * self.alpha
mid2 = np.power(mid, -self.beta)
return np.multiply(self.x, mid2), mid
def get_attrs(self):
attrs = {
'n': self.n,
'k': self.k,
'alpha': self.alpha,
'beta': self.beta
}
return attrs
def setUp(self):
self.op_type = "lrn"
self.N = 2
self.C = 3
self.H = 5
self.W = 5
self.n = 5
self.k = 2.0
self.alpha = 0.0001
self.beta = 0.75
self.x = self.get_input()
self.out, self.mid_out = self.get_out()
self.inputs = {'X': self.x}
self.outputs = {'Out': self.out, 'MidOut': self.mid_out}
self.attrs = self.get_attrs()
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
......@@ -36,7 +36,7 @@ class TestMomentumOptimizer(unittest.TestCase):
def get_velocity_str(self):
return self._velocity_acc_str
def test_momentum_optimizer(self):
def test_vanilla_momentum_optimizer(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
......@@ -60,6 +60,42 @@ class TestMomentumOptimizer(unittest.TestCase):
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertFalse(sgd_op.attr('useNesterov'))
# Check accumulators
accumulators = momentum_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 1)
self.assertTrue(momentum_optimizer.get_velocity_str() in accumulators)
velocity_acc = accumulators[momentum_optimizer.get_velocity_str()]
self.assertEqual(len(velocity_acc), 1)
self.assertTrue(mul_x.name in velocity_acc)
def test_nesterov_momentum_optimizer(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
momentum_optimizer = self.MockMomentum(
learning_rate=0.01, momentum=0.2, use_nesterov=True)
params_grads = append_backward_ops(mul_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(params_grads,
mul_out)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertTrue(sgd_op.attr('useNesterov'))
# Check accumulators
accumulators = momentum_optimizer.get_accumulators()
......@@ -160,5 +196,54 @@ class TestAdamOptimizer(unittest.TestCase):
self.assertTrue(mul_x.name in moment2_acc)
class TestAdamaxOptimizer(unittest.TestCase):
class MockAdamax(optimizer.AdamaxOptimizer):
def get_accumulators(self):
return self._accumulators
def get_moment_str(self):
return self._moment_acc_str
def get_inf_norm_str(self):
return self._inf_norm_acc_str
def test_adamax_optimizer(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
adamax_optimizer = self.MockAdamax(
learning_rate=0.01, beta1=0.9, beta2=0.999)
params_grads = append_backward_ops(mul_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out)
self.assertEqual(len(opts), 2)
adam_op = opts[0]
self.assertEqual(adam_op.type, "adamax")
# Check accumulators
accumulators = adamax_optimizer.get_accumulators()
self.assertEqual(len(accumulators), 2)
self.assertTrue(adamax_optimizer.get_moment_str() in accumulators)
self.assertTrue(adamax_optimizer.get_inf_norm_str() in accumulators)
moment_acc = accumulators[adamax_optimizer.get_moment_str()]
inf_norm_acc = accumulators[adamax_optimizer.get_inf_norm_str()]
self.assertEqual(len(moment_acc), 1)
self.assertEqual(len(inf_norm_acc), 1)
self.assertTrue(mul_x.name in moment_acc)
self.assertTrue(mul_x.name in inf_norm_acc)
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestProximalAdagradOp(OpTest):
def setUp(self):
self.op_type = "proximal_adagrad"
w = np.random.random((102, 105)).astype("float32")
m = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
lr = np.array([0.1]).astype("float32")
l1 = 0.1
l2 = 0.2
self.inputs = {'Param': w, 'Grad': g, 'Moment': m, 'LearningRate': lr}
self.attrs = {'l1': l1, 'l2': l2}
param_out = 0.0
moment_out = m + g * g
prox_param = w - lr * g / np.sqrt(moment_out)
if l1 > 0.0:
x = np.abs(prox_param) - lr * l1
x[x < 0] = 0
param_out = np.sign(prox_param) * (x / (1.0 + lr * l2))
else:
param_out = prox_param / (1.0 + lr * l2)
self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
此差异已折叠。
import numpy as np
import unittest
from numpy import linalg as LA
from op_test import OpTest
class TestL2LossOp(OpTest):
"""Test squared_l2_norm
"""
def setUp(self):
self.op_type = "squared_l2_norm"
self.max_relative_error = 0.05
X = np.random.uniform(-1, 1, (13, 19)).astype("float32")
X[np.abs(X) < self.max_relative_error] = 0.1
self.inputs = {'X': X}
self.outputs = {'Out': np.square(LA.norm(X))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['X'], 'Out', max_relative_error=self.max_relative_error)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册