提交 ee4a5d21 编写于 作者: S sweetsky0901

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into my_unpool_max_2d

......@@ -53,6 +53,15 @@ TBD
- GoogLeNet
| BatchSize | 64 | 128 | 256 |
|--------------|-------| ------| -------|
| OpenBLAS | 89.52 | 96.97 | 108.25 |
| MKLML | 128.46| 137.89| 158.63 |
| MKL-DNN     | 250.46| 264.83| 269.50 |
chart on batch size 128
TBD
### Laptop
TBD
### Desktop
......
# Python Data Reader Design Doc
At training and testing time, PaddlePaddle programs need to read data. To ease the users' work to write data reading code, we define that
During the training and testing phases, PaddlePaddle programs need to read data. To help the users write code that performs reading input data, we define the following:
- A *reader* is a function that reads data (from file, network, random number generator, etc) and yields data items.
- A *reader creator* is a function that returns a reader function.
- A *reader decorator* is a function, which accepts one or more readers, and returns a reader.
- A *batch reader* is a function that reads data (from *reader*, file, network, random number generator, etc) and yields a batch of data items.
- A *reader*: A function that reads data (from file, network, random number generator, etc) and yields the data items.
- A *reader creator*: A function that returns a reader function.
- A *reader decorator*: A function, which takes in one or more readers, and returns a reader.
- A *batch reader*: A function that reads data (from *reader*, file, network, random number generator, etc) and yields a batch of data items.
and provide function which converts reader to batch reader, frequently used reader creators and reader decorators.
and also provide a function which can convert a reader to a batch reader, frequently used reader creators and reader decorators.
## Data Reader Interface
Indeed, *data reader* doesn't have to be a function that reads and yields data items. It can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`):
*Data reader* doesn't have to be a function that reads and yields data items. It can just be any function without any parameters that creates an iterable (anything can be used in `for x in iterable`) as follows:
```
iterable = data_reader()
```
Element produced from the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int)
The item produced from the iterable should be a **single** entry of data and **not** a mini batch. The entry of data could be a single item or a tuple of items. Item should be of one of the [supported types](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int etc.)
An example implementation for single item data reader creator:
An example implementation for single item data reader creator is as follows:
```python
def reader_creator_random_image(width, height):
......@@ -29,7 +29,7 @@ def reader_creator_random_image(width, height):
return reader
```
An example implementation for multiple item data reader creator:
An example implementation for multiple item data reader creator is as follows:
```python
def reader_creator_random_image_and_label(width, height, label):
def reader():
......@@ -40,9 +40,10 @@ def reader_creator_random_image_and_label(width, height, label):
## Batch Reader Interface
*batch reader* can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`). The output of the iterable should be a batch (list) of data items. Each item inside the list must be a tuple.
*Batch reader* can be any function without any parameters that creates an iterable (anything can be used in `for x in iterable`). The output of the iterable should be a batch (list) of data items. Each item inside the list should be a tuple.
Here are some valid outputs:
Here are valid outputs:
```python
# a mini batch of three data items. Each data item consist three columns of data, each of which is 1.
[(1, 1, 1),
......@@ -58,20 +59,22 @@ Here are valid outputs:
Please note that each item inside the list must be a tuple, below is an invalid output:
```python
# wrong, [1,1,1] needs to be inside a tuple: ([1,1,1],).
# Otherwise it's ambiguous whether [1,1,1] means a single column of data [1, 1, 1],
# or three column of datas, each of which is 1.
# Otherwise it is ambiguous whether [1,1,1] means a single column of data [1, 1, 1],
# or three columns of data, each of which is 1.
[[1,1,1],
[2,2,2],
[3,3,3]]
```
It's easy to convert from reader to batch reader:
It is easy to convert from a reader to a batch reader:
```python
mnist_train = paddle.dataset.mnist.train()
mnist_train_batch_reader = paddle.batch(mnist_train, 128)
```
Also easy to create custom batch reader:
It is also straight forward to create a custom batch reader:
```python
def custom_batch_reader():
while True:
......@@ -85,7 +88,8 @@ mnist_random_image_batch_reader = custom_batch_reader
## Usage
batch reader, mapping from item(s) read to data layer, batch size and number of total pass will be passed into `paddle.train`:
Following is how we can use the reader with PaddlePaddle:
The batch reader, a mapping from item(s) to data layer, the batch size and the number of total passes will be passed into `paddle.train` as follows:
```python
# two data layer is created:
......@@ -99,13 +103,13 @@ paddle.train(batch_reader, {"image":0, "label":1}, 128, 10, ...)
## Data Reader Decorator
*Data reader decorator* takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax.
The *Data reader decorator* takes in a single reader or multiple data readers and returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` in the syntax.
Since we have a strict interface for data readers (no parameter, return a single data item). Data reader can be used flexiable via data reader decorators. Following are a few examples:
Since we have a strict interface for data readers (no parameters and return a single data item), a data reader can be used in a flexible way using data reader decorators. Following are a few examples:
### Prefetch Data
Since reading data may take time and training can not proceed without data. It is generally a good idea to prefetch data.
Since reading data may take some time and training can not proceed without data, it is generally a good idea to prefetch the data.
Use `paddle.reader.buffered` to prefetch data:
......@@ -117,9 +121,9 @@ buffered_reader = paddle.reader.buffered(paddle.dataset.mnist.train(), 100)
### Compose Multiple Data Readers
For example, we want to use a source of real images (reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).
For example, if we want to use a source of real images (say reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).
We can do:
We can do the following :
```python
def reader_creator_random_image(width, height):
......@@ -139,13 +143,13 @@ false_reader = reader_creator_bool(False)
reader = paddle.reader.compose(paddle.dataset.mnist.train(), data_reader_creator_random_image(20, 20), true_reader, false_reader)
# Skipped 1 because paddle.dataset.mnist.train() produces two items per data entry.
# And we don't care second item at this time.
# And we don't care about the second item at this time.
paddle.train(paddle.batch(reader, 128), {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...)
```
### Shuffle
Given shuffle buffer size `n`, `paddle.reader.shuffle` will return a data reader that buffers `n` data entries and shuffle them before a data entry is read.
Given the shuffle buffer size `n`, `paddle.reader.shuffle` returns a data reader that buffers `n` data entries and shuffles them before a data entry is read.
Example:
```python
......@@ -154,21 +158,21 @@ reader = paddle.reader.shuffle(paddle.dataset.mnist.train(), 512)
## Q & A
### Why reader return only a single entry, but not a mini batch?
### Why does a reader return only a single entry, and not a mini batch?
Always returning a single entry make reusing existing data readers much easier (e.g., if existing reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2).
Returning a single entry makes reusing existing data readers much easier (for example, if an existing reader returns 3 entries instead if a single entry, the training code will be more complicated because it need to handle cases like a batch size 2).
We provide function `paddle.batch` to turn (single entry) reader into batch reader.
We provide a function: `paddle.batch` to turn (a single entry) reader into a batch reader.
### Why do we need batch reader, isn't train take reader and batch_size as arguments sufficient?
### Why do we need a batch reader, isn't is sufficient to give the reader and batch_size as arguments during training ?
In most of the case, train taking reader and batch_size as arguments would be sufficent. However sometimes user want to customize order of data entries inside a mini batch. Or even change batch size dynamically.
In most of the cases, it would be sufficient to give the reader and batch_size as arguments to the train method. However sometimes the user wants to customize the order of data entries inside a mini batch, or even change the batch size dynamically. For these cases using a batch reader is very efficient and helpful.
### Why use a dictionary but not a list to provide mapping?
### Why use a dictionary instead of a list to provide mapping?
We decided to use dictionary (`{"image":0, "label":1}`) instead of list (`["image", "label"]`) is because that user can easily resue item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip item (e.g., using `{"image_a":0, "label":2}`).
Using a dictionary (`{"image":0, "label":1}`) instead of a list (`["image", "label"]`) gives the advantage that the user can easily reuse the items (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or even skip an item (e.g., using `{"image_a":0, "label":2}`).
### How to create custom data reader creator
### How to create a custom data reader creator ?
```python
def image_reader_creator(image_path, label_path, n):
......@@ -192,7 +196,7 @@ paddle.train(paddle.batch(reader, 128), {"image":0, "label":1}, ...)
### How is `paddle.train` implemented
An example implementation of paddle.train could be:
An example implementation of paddle.train is:
```python
def train(batch_reader, mapping, batch_size, total_pass):
......
#include <paddle/capi.h>
#include <time.h>
#include "../common/common.h"
#define CONFIG_BIN "./trainer_config.bin"
......@@ -27,20 +28,19 @@ int main() {
CHECK(paddle_arguments_resize(in_args, 1));
// Create input matrix.
paddle_matrix mat = paddle_matrix_create(/* sample_num */ 10,
paddle_matrix mat = paddle_matrix_create(/* sample_num */ 1,
/* size */ 784,
/* useGPU */ false);
srand(time(0));
std::vector<paddle_real> input;
input.resize(784 * 10);
paddle_real* array;
// Get First row.
CHECK(paddle_matrix_get_row(mat, 0, &array));
for (int i = 0; i < input.size(); ++i) {
input[i] = rand() / ((float)RAND_MAX);
for (int i = 0; i < 784; ++i) {
array[i] = rand() / ((float)RAND_MAX);
}
// Set value for the input matrix
CHECK(paddle_matrix_set_value(mat, input.data()));
CHECK(paddle_arguments_set_value(in_args, 0, mat));
......@@ -53,17 +53,18 @@ int main() {
CHECK(paddle_arguments_get_value(out_args, 0, prob));
std::std::vector<paddle_real> result;
int height;
int width;
uint64_t height;
uint64_t width;
CHECK(paddle_matrix_get_shape(prob, &height, &width);
result.resize(height * width);
CHECK(paddle_matrix_get_value(prob, result.data()));
CHECK(paddle_matrix_get_shape(prob, &height, &width));
CHECK(paddle_matrix_get_row(prob, 0, &array));
printf("Prob: ");
printf("Prob: \n");
for (int i = 0; i < height * width; ++i) {
printf("%.2f ", result[i]);
printf("%.4f ", array[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
......
......@@ -41,6 +41,7 @@ bool BatchNormBaseLayer::init(const LayerMap& layerMap,
useGlobalStats_ = config_.use_global_stats();
}
movingAvgFraction_ = config_.moving_average_fraction();
epsilon_ = config_.epsilon();
weight_.reset(new Weight(1, channels_, parameters_[0]));
movingMean_.reset(new Weight(1, channels_, parameters_[1]));
......
......@@ -94,6 +94,8 @@ protected:
bool useGlobalStats_;
// use to compute moving mean and variance.
real movingAvgFraction_;
// Epsilon is a small random noise used in batch normalization for stability.
real epsilon_;
};
} // namespace paddle
......@@ -22,8 +22,6 @@ namespace paddle {
REGISTER_LAYER(batch_norm, BatchNormalizationLayer);
const real BatchNormalizationLayer::EPS = 1E-5;
bool BatchNormalizationLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
......@@ -53,7 +51,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
calMovingMeanAndVar();
savedInvVar_->subScalar(-EPS);
savedInvVar_->subScalar(-epsilon_);
savedInvVar_->sqrt2(*savedInvVar_);
}
......@@ -74,7 +72,7 @@ void BatchNormalizationLayer::setMeanAndStd() {
savedInvVar_->copyFrom(*(movingVar_->getW()));
savedInvVar_->downClip(real(0.0));
savedInvVar_->subScalar(-EPS);
savedInvVar_->subScalar(-epsilon_);
savedInvVar_->sqrt2(*savedInvVar_);
}
......
......@@ -39,9 +39,6 @@ public:
void backward(const UpdateCallback& callback = nullptr) override;
protected:
/// Epsilon value used in the batch normalization formula.
static const real EPS;
/// Load pre-calculated mean and std.
void setMeanAndStd();
......
......@@ -21,8 +21,6 @@ namespace paddle {
REGISTER_LAYER(cudnn_batch_norm, CudnnBatchNormLayer);
const double CudnnBatchNormLayer::EPS = 1E-5;
bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
......@@ -61,6 +59,9 @@ void CudnnBatchNormLayer::forward(PassType passType) {
real* movingMean = movingMean_->getW()->getData();
real* movingVar = movingVar_->getW()->getData();
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
if (!useGlobalStats_) {
REGISTER_TIMER_INFO("CudnnBatchFwTimer", getName().c_str());
real* savedMean = savedMean_->getData();
......@@ -75,7 +76,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
1.0 - movingAvgFraction_,
movingMean,
movingVar,
EPS,
eps_,
savedMean,
savedInvVar);
} else {
......@@ -90,7 +91,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
beta,
movingMean,
movingVar,
EPS);
eps_);
} else {
// There is a limitation in cudnn library.
// When the batch size is larger than 1024 in cuDNN v5.1,
......@@ -101,7 +102,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
beta,
movingMean,
movingVar,
EPS,
eps_,
batchSize,
channels_,
imageH_ * imageD_,
......@@ -128,6 +129,9 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
real* savedMean = savedMean_->getData();
real* savedInvVar = savedInvVar_->getData();
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
auto create = [](MatrixPtr& m, size_t h, size_t w, real** p) {
Matrix::resizeOrCreate(m, h, w, false, true);
m->zeroMem();
......@@ -157,7 +161,7 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
gamma,
gammaGrad,
betaGrad,
EPS,
eps_,
savedMean,
savedInvVar);
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <cudnn.h>
#include "BatchNormBaseLayer.h"
#include "Layer.h"
#include "paddle/utils/Stat.h"
......@@ -46,12 +47,9 @@ public:
void backward(const UpdateCallback& callback = nullptr) override;
protected:
/**
* Epsilon value used in the batch normalization formula.
* Minimum allowed value is CUDNN_BN_MIN_EPSILON defined in cudnn.h.
* Same epsilon value should be used in forward and backward functions.
*/
static const double EPS;
/// Epsilon value used in the batch normalization formula.
/// Same epsilon value should be used in forward and backward functions.
double eps_;
/// Input/output tensor descriptor desc
hl_tensor_descriptor ioDesc_;
......
......@@ -21,8 +21,6 @@ namespace paddle {
REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
const real MKLDNNBatchNormLayer::EPS = 1E-5;
bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
......@@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
useGlobalStats_ = config_.use_global_stats();
}
movingAvgFraction_ = config_.moving_average_fraction();
epsilon_ = config_.epsilon();
VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
<< " --- global stats";
VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
......@@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD(
if (wgt) {
flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
}
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), epsilon_, flags_);
pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
if (wgt) {
......@@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD(
}
CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc());
auto md = in->getMemoryDesc();
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, epsilon_, flags_);
pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc());
......
......@@ -32,7 +32,8 @@ protected:
std::shared_ptr<bn_fwd::primitive_desc> fwdPD_;
// Epsilon value used in the batch normalization formula.
static const real EPS;
real epsilon_;
// weight and bias in paddle
std::unique_ptr<Weight> weight_;
std::unique_ptr<Weight> biases_;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ftrl_op.h"
namespace paddle {
namespace operators {
class FTRLOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SquaredAccumulator"),
"Input(SquaredAccumulator) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LinearAccumulator"),
"Input(LinearAccumulator) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("SquaredAccumOut"),
"Output(SquaredAccumOut) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LinearAccumOut"),
"Output(LinearAccumOut) of FTRL should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of FTRL Op's dimension must be same.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("SquaredAccumOut", param_dim);
ctx->SetOutputDim("LinearAccumOut", param_dim);
}
};
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
AddInput("SquaredAccumulator",
"(Tensor, default Tensor<float>) "
"Accumulator that accumulates squared gradients.");
AddInput("LinearAccumulator",
"(Tensor, default Tensor<float>) "
"Accumulator that accumulates linear gradients.");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter.");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1.");
AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
AddOutput("SquaredAccumOut",
"(Tensor) Output accumulated squared"
" gradients.");
AddOutput("LinearAccumOut",
"(Tensor) Output accumulated linear"
" gradients.");
AddAttr<float>("l1",
"(float, default 0.0) "
"L1 regularization strength.")
.SetDefault(0.0f);
AddAttr<float>("l2",
"(float, default 0.0) "
"L2 regularization strength.")
.SetDefault(0.0f);
AddAttr<float>("lr_power",
"(float, default -0.5f) "
"Learning Rate Power.")
.SetDefault(-0.5f);
AddComment(R"DOC(
FTRL (Follow The Regularized Leader) Operator.
Optimizer that implements the FTRL algorithm:
$$
new\_accum = squared\_accum + grad^2 \\
if (lr\_power == -0.5) {
linear\_accum += grad - (\surd(new\_accum) - \surd(squared\_accum)) /
(learning\_rate * param) \\
} else {
linear\_accum += grad -
(new\_accum^{-lr\_power} - accum^{-lr\_power}) /
(learning\_rate * param) \\
}
x = (l1 * sign(linear\_accum) - linear\_accum)
if (lr\_power == -0.5) {
y = \frac{\surd(new\_accum)}{learning\_rate} + (2 * l2) \\
pre\_shrink = \frac{x}{y} \\
param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\
} else {
y = \frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2) \\
pre\_shrink = \frac{x}{y} \\
param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\
}
squared\_accum += grad^2;
$$
The paper that proposed Follow The Regularized Leader (FTRL):
(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker);
REGISTER_OP_CPU_KERNEL(ftrl,
ops::FTRLOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/ftrl_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ftrl,
ops::FTRLOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");
param_out->mutable_data<T>(ctx.GetPlace());
sq_accum_out->mutable_data<T>(ctx.GetPlace());
lin_accum_out->mutable_data<T>(ctx.GetPlace());
auto grad = ctx.Input<Tensor>("Grad");
auto l1 = static_cast<T>(ctx.Attr<float>("l1"));
auto l2 = static_cast<T>(ctx.Attr<float>("l2"));
auto lr_power = static_cast<T>(ctx.Attr<float>("lr_power"));
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto sq_accum =
EigenVector<T>::Flatten(*ctx.Input<Tensor>("SquaredAccumulator"));
auto lin_accum =
EigenVector<T>::Flatten(*ctx.Input<Tensor>("LinearAccumulator"));
auto g = EigenVector<T>::Flatten(*grad);
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto s_acc_out = EigenVector<T>::Flatten(*sq_accum_out);
auto l_acc_out = EigenVector<T>::Flatten(*lin_accum_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
auto new_accum = sq_accum + g * g;
// Special case for lr_power = -0.5
if (lr_power == static_cast<T>(-0.5)) {
l_acc_out.device(place) =
lin_accum + g -
((new_accum.sqrt() - sq_accum.sqrt()) / lr.broadcast(grad_dsize)) * p;
} else {
l_acc_out.device(place) =
lin_accum + g -
((new_accum.pow(-lr_power) - sq_accum.pow(-lr_power)) /
lr.broadcast(grad_dsize)) *
p;
}
auto x = (l_acc_out.constant(l1) * l_acc_out.sign() - l_acc_out);
if (lr_power == static_cast<T>(-0.5)) {
auto y = (new_accum.sqrt() / lr.broadcast(grad_dsize)) +
l_acc_out.constant(static_cast<T>(2) * l2);
auto pre_shrink = x / y;
p_out.device(place) =
(l_acc_out.abs() > l_acc_out.constant(l1))
.select(pre_shrink, p.constant(static_cast<T>(0)));
} else {
auto y = (new_accum.pow(-lr_power) / lr.broadcast(grad_dsize)) +
l_acc_out.constant(static_cast<T>(2) * l2);
auto pre_shrink = x / y;
p_out.device(place) =
(l_acc_out.abs() > l_acc_out.constant(l1))
.select(pre_shrink, p.constant(static_cast<T>(0)));
}
s_acc_out.device(place) = sq_accum + g * g;
}
};
} // namespace operators
} // namespace paddle
......@@ -540,6 +540,10 @@ message LayerConfig {
// for switch order layer
optional ReshapeConfig reshape_conf = 59;
// for batch normalization layer
// The small constant added to the variance to improve numeric stability.
optional double epsilon = 60 [ default = 0.00001 ];
}
message EvaluatorConfig {
......
......@@ -2412,6 +2412,7 @@ class BatchNormLayer(LayerBase):
bias=True,
img3D=False,
use_global_stats=True,
epsilon=1e-5,
moving_average_fraction=0.9,
batch_norm_type=None,
mean_var_names=None,
......@@ -2460,6 +2461,9 @@ class BatchNormLayer(LayerBase):
self.config.use_global_stats = use_global_stats
if moving_average_fraction is not None:
self.config.moving_average_fraction = moving_average_fraction
if epsilon is not None:
assert epsilon >= 1e-5, "epsilon must be no less than 1e-5."
self.config.epsilon = epsilon
input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].image_conf
......
......@@ -3118,6 +3118,7 @@ def batch_norm_layer(input,
param_attr=None,
layer_attr=None,
batch_norm_type=None,
epsilon=1e-5,
moving_average_fraction=0.9,
use_global_stats=None,
mean_var_names=None):
......@@ -3188,6 +3189,8 @@ def batch_norm_layer(input,
will use the mean and variance of the current batch
of test data.
:type use_global_stats: bool | None.
:param epsilon: The small constant added to the variance to improve numeric stability.
:type epsilon: float.
:param moving_average_fraction: Factor used in the moving average computation.
:math:`runningMean = newMean*(1-factor) + runningMean*factor`
:type moving_average_fraction: float.
......@@ -3205,6 +3208,7 @@ def batch_norm_layer(input,
assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \
(batch_norm_type == "mkldnn_batch_norm") or \
(batch_norm_type == "cudnn_batch_norm")
l = Layer(
name=name,
img3D=img3D,
......@@ -3214,6 +3218,7 @@ def batch_norm_layer(input,
type=LayerType.BATCH_NORM_LAYER,
batch_norm_type=batch_norm_type,
bias=ParamAttr.to_bias(bias_attr),
epsilon=epsilon,
moving_average_fraction=moving_average_fraction,
use_global_stats=use_global_stats,
mean_var_names=mean_var_names,
......
......@@ -65,6 +65,7 @@ layers {
height: 227
width: 227
depth: 1
epsilon: 1e-05
}
layers {
name: "__crmnorm_0__"
......
......@@ -65,6 +65,7 @@ layers {
height: 256
width: 256
depth: 1
epsilon: 1e-05
}
layers {
name: "__crmnorm_0__"
......
......@@ -36,6 +36,7 @@ layers {
height: 6
width: 20
depth: 3
epsilon: 1e-05
}
parameters {
name: "___batch_norm_0__.w0"
......
......@@ -91,14 +91,14 @@ def set_omp_mkl_env_vars(trainer_count):
.read())
return num_sockets * num_cores_per_socket
else:
cmds = {"Darwin": "sysctl hw.physicalcpu"}
cmds = {"Darwin": "sysctl -n hw.physicalcpu"}
return int(os.popen(cmds.get(platform.system(), "expr 1")).read())
def num_logical_processors():
'''Get the number of logical processors'''
cmds = {
"Linux": "grep \"processor\" /proc/cpuinfo|sort -u|wc -l",
"Darwin": "sysctl hw.logicalcpu"
"Darwin": "sysctl -n hw.logicalcpu"
}
return int(os.popen(cmds.get(platform.system(), "expr 1")).read())
......
import unittest
import numpy as np
from op_test import OpTest
class TestFTRLOp(OpTest):
def setUp(self):
self.op_type = "ftrl"
w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
sq_accum = np.full((102, 105), 0.1).astype("float32")
linear_accum = np.full((102, 105), 0.1).astype("float32")
lr = np.array([0.01]).astype("float32")
l1 = 0.1
l2 = 0.2
lr_power = -0.5
self.inputs = {
'Param': w,
'SquaredAccumulator': sq_accum,
'LinearAccumulator': linear_accum,
'Grad': g,
'LearningRate': lr
}
self.attrs = {
'l1': l1,
'l2': l2,
'lr_power': lr_power,
'learning_rate': lr
}
new_accum = sq_accum + g * g
if lr_power == -0.5:
linear_out = linear_accum + g - (
(np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * w
else:
linear_out = linear_accum + g - ((np.power(
new_accum, -lr_power) - np.power(sq_accum, -lr_power)) / lr) * w
x = (l1 * np.sign(linear_out) - linear_out)
if lr_power == -0.5:
y = (np.sqrt(new_accum) / lr) + (2 * l2)
pre_shrink = x / y
param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0)
else:
y = (np.power(new_accum, -lr_power) / lr) + (2 * l2)
pre_shrink = x / y
param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0)
sq_accum_out = sq_accum + g * g
self.outputs = {
'ParamOut': param_out,
'SquaredAccumOut': sq_accum_out,
'LinearAccumOut': linear_out
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册