diff --git a/benchmark/IntelOptimizedPaddle.md b/benchmark/IntelOptimizedPaddle.md index ab0be77324450521fee02b7bd7ea12fb9eacf86a..16c2390fd31bf1c79f29735fb98180d3f7302eb2 100644 --- a/benchmark/IntelOptimizedPaddle.md +++ b/benchmark/IntelOptimizedPaddle.md @@ -53,6 +53,15 @@ TBD - GoogLeNet +| BatchSize | 64 | 128 | 256 | +|--------------|-------| ------| -------| +| OpenBLAS | 89.52 | 96.97 | 108.25 | +| MKLML | 128.46| 137.89| 158.63 | +| MKL-DNN     | 250.46| 264.83| 269.50 | + +chart on batch size 128 +TBD + ### Laptop TBD ### Desktop diff --git a/doc/design/reader/README.md b/doc/design/reader/README.md index 320dccec3ddc7bfe6042f4e65b2518ea7b1ad24a..2cd4b6225b61cf374458e40afabad7745f61ba71 100644 --- a/doc/design/reader/README.md +++ b/doc/design/reader/README.md @@ -1,25 +1,25 @@ # Python Data Reader Design Doc -At training and testing time, PaddlePaddle programs need to read data. To ease the users' work to write data reading code, we define that +During the training and testing phases, PaddlePaddle programs need to read data. To help the users write code that performs reading input data, we define the following: -- A *reader* is a function that reads data (from file, network, random number generator, etc) and yields data items. -- A *reader creator* is a function that returns a reader function. -- A *reader decorator* is a function, which accepts one or more readers, and returns a reader. -- A *batch reader* is a function that reads data (from *reader*, file, network, random number generator, etc) and yields a batch of data items. +- A *reader*: A function that reads data (from file, network, random number generator, etc) and yields the data items. +- A *reader creator*: A function that returns a reader function. +- A *reader decorator*: A function, which takes in one or more readers, and returns a reader. +- A *batch reader*: A function that reads data (from *reader*, file, network, random number generator, etc) and yields a batch of data items. -and provide function which converts reader to batch reader, frequently used reader creators and reader decorators. +and also provide a function which can convert a reader to a batch reader, frequently used reader creators and reader decorators. ## Data Reader Interface -Indeed, *data reader* doesn't have to be a function that reads and yields data items. It can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`): +*Data reader* doesn't have to be a function that reads and yields data items. It can just be any function without any parameters that creates an iterable (anything can be used in `for x in iterable`) as follows: ``` iterable = data_reader() ``` -Element produced from the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int) +The item produced from the iterable should be a **single** entry of data and **not** a mini batch. The entry of data could be a single item or a tuple of items. Item should be of one of the [supported types](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int etc.) -An example implementation for single item data reader creator: +An example implementation for single item data reader creator is as follows: ```python def reader_creator_random_image(width, height): @@ -29,7 +29,7 @@ def reader_creator_random_image(width, height): return reader ``` -An example implementation for multiple item data reader creator: +An example implementation for multiple item data reader creator is as follows: ```python def reader_creator_random_image_and_label(width, height, label): def reader(): @@ -40,9 +40,10 @@ def reader_creator_random_image_and_label(width, height, label): ## Batch Reader Interface -*batch reader* can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`). The output of the iterable should be a batch (list) of data items. Each item inside the list must be a tuple. +*Batch reader* can be any function without any parameters that creates an iterable (anything can be used in `for x in iterable`). The output of the iterable should be a batch (list) of data items. Each item inside the list should be a tuple. + +Here are some valid outputs: -Here are valid outputs: ```python # a mini batch of three data items. Each data item consist three columns of data, each of which is 1. [(1, 1, 1), @@ -58,20 +59,22 @@ Here are valid outputs: Please note that each item inside the list must be a tuple, below is an invalid output: ```python # wrong, [1,1,1] needs to be inside a tuple: ([1,1,1],). - # Otherwise it's ambiguous whether [1,1,1] means a single column of data [1, 1, 1], - # or three column of datas, each of which is 1. + # Otherwise it is ambiguous whether [1,1,1] means a single column of data [1, 1, 1], + # or three columns of data, each of which is 1. [[1,1,1], [2,2,2], [3,3,3]] ``` -It's easy to convert from reader to batch reader: +It is easy to convert from a reader to a batch reader: + ```python mnist_train = paddle.dataset.mnist.train() mnist_train_batch_reader = paddle.batch(mnist_train, 128) ``` -Also easy to create custom batch reader: +It is also straight forward to create a custom batch reader: + ```python def custom_batch_reader(): while True: @@ -85,7 +88,8 @@ mnist_random_image_batch_reader = custom_batch_reader ## Usage -batch reader, mapping from item(s) read to data layer, batch size and number of total pass will be passed into `paddle.train`: +Following is how we can use the reader with PaddlePaddle: +The batch reader, a mapping from item(s) to data layer, the batch size and the number of total passes will be passed into `paddle.train` as follows: ```python # two data layer is created: @@ -99,13 +103,13 @@ paddle.train(batch_reader, {"image":0, "label":1}, 128, 10, ...) ## Data Reader Decorator -*Data reader decorator* takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax. +The *Data reader decorator* takes in a single reader or multiple data readers and returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` in the syntax. -Since we have a strict interface for data readers (no parameter, return a single data item). Data reader can be used flexiable via data reader decorators. Following are a few examples: +Since we have a strict interface for data readers (no parameters and return a single data item), a data reader can be used in a flexible way using data reader decorators. Following are a few examples: ### Prefetch Data -Since reading data may take time and training can not proceed without data. It is generally a good idea to prefetch data. +Since reading data may take some time and training can not proceed without data, it is generally a good idea to prefetch the data. Use `paddle.reader.buffered` to prefetch data: @@ -117,9 +121,9 @@ buffered_reader = paddle.reader.buffered(paddle.dataset.mnist.train(), 100) ### Compose Multiple Data Readers -For example, we want to use a source of real images (reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). +For example, if we want to use a source of real images (say reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). -We can do: +We can do the following : ```python def reader_creator_random_image(width, height): @@ -139,13 +143,13 @@ false_reader = reader_creator_bool(False) reader = paddle.reader.compose(paddle.dataset.mnist.train(), data_reader_creator_random_image(20, 20), true_reader, false_reader) # Skipped 1 because paddle.dataset.mnist.train() produces two items per data entry. -# And we don't care second item at this time. +# And we don't care about the second item at this time. paddle.train(paddle.batch(reader, 128), {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...) ``` ### Shuffle -Given shuffle buffer size `n`, `paddle.reader.shuffle` will return a data reader that buffers `n` data entries and shuffle them before a data entry is read. +Given the shuffle buffer size `n`, `paddle.reader.shuffle` returns a data reader that buffers `n` data entries and shuffles them before a data entry is read. Example: ```python @@ -154,21 +158,21 @@ reader = paddle.reader.shuffle(paddle.dataset.mnist.train(), 512) ## Q & A -### Why reader return only a single entry, but not a mini batch? +### Why does a reader return only a single entry, and not a mini batch? -Always returning a single entry make reusing existing data readers much easier (e.g., if existing reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2). +Returning a single entry makes reusing existing data readers much easier (for example, if an existing reader returns 3 entries instead if a single entry, the training code will be more complicated because it need to handle cases like a batch size 2). -We provide function `paddle.batch` to turn (single entry) reader into batch reader. +We provide a function: `paddle.batch` to turn (a single entry) reader into a batch reader. -### Why do we need batch reader, isn't train take reader and batch_size as arguments sufficient? +### Why do we need a batch reader, isn't is sufficient to give the reader and batch_size as arguments during training ? -In most of the case, train taking reader and batch_size as arguments would be sufficent. However sometimes user want to customize order of data entries inside a mini batch. Or even change batch size dynamically. +In most of the cases, it would be sufficient to give the reader and batch_size as arguments to the train method. However sometimes the user wants to customize the order of data entries inside a mini batch, or even change the batch size dynamically. For these cases using a batch reader is very efficient and helpful. -### Why use a dictionary but not a list to provide mapping? +### Why use a dictionary instead of a list to provide mapping? -We decided to use dictionary (`{"image":0, "label":1}`) instead of list (`["image", "label"]`) is because that user can easily resue item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip item (e.g., using `{"image_a":0, "label":2}`). +Using a dictionary (`{"image":0, "label":1}`) instead of a list (`["image", "label"]`) gives the advantage that the user can easily reuse the items (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or even skip an item (e.g., using `{"image_a":0, "label":2}`). -### How to create custom data reader creator +### How to create a custom data reader creator ? ```python def image_reader_creator(image_path, label_path, n): @@ -192,7 +196,7 @@ paddle.train(paddle.batch(reader, 128), {"image":0, "label":1}, ...) ### How is `paddle.train` implemented -An example implementation of paddle.train could be: +An example implementation of paddle.train is: ```python def train(batch_reader, mapping, batch_size, total_pass): diff --git a/paddle/capi/examples/model_inference/dense/main.c b/paddle/capi/examples/model_inference/dense/main.c index 876af2aa7615c098d225b56ce2ea0b1529a6e3c6..5eeaf7e31fac7c9ed0b9269e74a7e467bde155ef 100644 --- a/paddle/capi/examples/model_inference/dense/main.c +++ b/paddle/capi/examples/model_inference/dense/main.c @@ -1,5 +1,6 @@ #include #include + #include "../common/common.h" #define CONFIG_BIN "./trainer_config.bin" @@ -27,20 +28,19 @@ int main() { CHECK(paddle_arguments_resize(in_args, 1)); // Create input matrix. - paddle_matrix mat = paddle_matrix_create(/* sample_num */ 10, + paddle_matrix mat = paddle_matrix_create(/* sample_num */ 1, /* size */ 784, /* useGPU */ false); srand(time(0)); - std::vector input; - input.resize(784 * 10); + paddle_real* array; + + // Get First row. + CHECK(paddle_matrix_get_row(mat, 0, &array)); - for (int i = 0; i < input.size(); ++i) { - input[i] = rand() / ((float)RAND_MAX); + for (int i = 0; i < 784; ++i) { + array[i] = rand() / ((float)RAND_MAX); } - - // Set value for the input matrix - CHECK(paddle_matrix_set_value(mat, input.data())); CHECK(paddle_arguments_set_value(in_args, 0, mat)); @@ -53,17 +53,18 @@ int main() { CHECK(paddle_arguments_get_value(out_args, 0, prob)); - std::std::vector result; - int height; - int width; + uint64_t height; + uint64_t width; - CHECK(paddle_matrix_get_shape(prob, &height, &width); - result.resize(height * width); - CHECK(paddle_matrix_get_value(prob, result.data())); + CHECK(paddle_matrix_get_shape(prob, &height, &width)); + CHECK(paddle_matrix_get_row(prob, 0, &array)); - printf("Prob: "); + printf("Prob: \n"); for (int i = 0; i < height * width; ++i) { - printf("%.2f ", result[i]); + printf("%.4f ", array[i]); + if ((i + 1) % width == 0) { + printf("\n"); + } } printf("\n"); diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp index bc7d1c83a48aefeb4bc6d3baa32b78aba712e58d..925af31289d0c8ca534a30a16b14bfd2df90b013 100644 --- a/paddle/gserver/layers/BatchNormBaseLayer.cpp +++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp @@ -41,6 +41,7 @@ bool BatchNormBaseLayer::init(const LayerMap& layerMap, useGlobalStats_ = config_.use_global_stats(); } movingAvgFraction_ = config_.moving_average_fraction(); + epsilon_ = config_.epsilon(); weight_.reset(new Weight(1, channels_, parameters_[0])); movingMean_.reset(new Weight(1, channels_, parameters_[1])); diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h index e721d2d267a31cae46407673b8b1281e87055608..2ac3cd9d670d0fcf9c40ad2f117d5a72479663a3 100644 --- a/paddle/gserver/layers/BatchNormBaseLayer.h +++ b/paddle/gserver/layers/BatchNormBaseLayer.h @@ -94,6 +94,8 @@ protected: bool useGlobalStats_; // use to compute moving mean and variance. real movingAvgFraction_; + // Epsilon is a small random noise used in batch normalization for stability. + real epsilon_; }; } // namespace paddle diff --git a/paddle/gserver/layers/BatchNormalizationLayer.cpp b/paddle/gserver/layers/BatchNormalizationLayer.cpp index dacff25e5927daf9c991577a71be86b160228317..25ab5cd927792d18f78bc1fa33eee4029b427cc7 100644 --- a/paddle/gserver/layers/BatchNormalizationLayer.cpp +++ b/paddle/gserver/layers/BatchNormalizationLayer.cpp @@ -22,8 +22,6 @@ namespace paddle { REGISTER_LAYER(batch_norm, BatchNormalizationLayer); -const real BatchNormalizationLayer::EPS = 1E-5; - bool BatchNormalizationLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ @@ -53,7 +51,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) { calMovingMeanAndVar(); - savedInvVar_->subScalar(-EPS); + savedInvVar_->subScalar(-epsilon_); savedInvVar_->sqrt2(*savedInvVar_); } @@ -74,7 +72,7 @@ void BatchNormalizationLayer::setMeanAndStd() { savedInvVar_->copyFrom(*(movingVar_->getW())); savedInvVar_->downClip(real(0.0)); - savedInvVar_->subScalar(-EPS); + savedInvVar_->subScalar(-epsilon_); savedInvVar_->sqrt2(*savedInvVar_); } diff --git a/paddle/gserver/layers/BatchNormalizationLayer.h b/paddle/gserver/layers/BatchNormalizationLayer.h index f6115801fc6b341c0718f8851617de43bdeeec09..1fdb5e2070259a14ab6f70957c9cf03f0699f734 100644 --- a/paddle/gserver/layers/BatchNormalizationLayer.h +++ b/paddle/gserver/layers/BatchNormalizationLayer.h @@ -39,9 +39,6 @@ public: void backward(const UpdateCallback& callback = nullptr) override; protected: - /// Epsilon value used in the batch normalization formula. - static const real EPS; - /// Load pre-calculated mean and std. void setMeanAndStd(); diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index 49a9540c0b6e36b59ed786287ff5c4569b69a6a5..8390b55026c895b661cb514714ba92c05a7bf02e 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -21,8 +21,6 @@ namespace paddle { REGISTER_LAYER(cudnn_batch_norm, CudnnBatchNormLayer); -const double CudnnBatchNormLayer::EPS = 1E-5; - bool CudnnBatchNormLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ @@ -61,6 +59,9 @@ void CudnnBatchNormLayer::forward(PassType passType) { real* movingMean = movingMean_->getW()->getData(); real* movingVar = movingVar_->getW()->getData(); + // cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON. + eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast(epsilon_)); + if (!useGlobalStats_) { REGISTER_TIMER_INFO("CudnnBatchFwTimer", getName().c_str()); real* savedMean = savedMean_->getData(); @@ -75,7 +76,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { 1.0 - movingAvgFraction_, movingMean, movingVar, - EPS, + eps_, savedMean, savedInvVar); } else { @@ -90,7 +91,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { beta, movingMean, movingVar, - EPS); + eps_); } else { // There is a limitation in cudnn library. // When the batch size is larger than 1024 in cuDNN v5.1, @@ -101,7 +102,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { beta, movingMean, movingVar, - EPS, + eps_, batchSize, channels_, imageH_ * imageD_, @@ -128,6 +129,9 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { real* savedMean = savedMean_->getData(); real* savedInvVar = savedInvVar_->getData(); + // cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON. + eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast(epsilon_)); + auto create = [](MatrixPtr& m, size_t h, size_t w, real** p) { Matrix::resizeOrCreate(m, h, w, false, true); m->zeroMem(); @@ -157,7 +161,7 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { gamma, gammaGrad, betaGrad, - EPS, + eps_, savedMean, savedInvVar); diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.h b/paddle/gserver/layers/CudnnBatchNormLayer.h index 413efd4d3ecd734b343efbcf8328ac0592daddda..1a3f0c0cbf8a1540e77cef70c753c91298728484 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.h +++ b/paddle/gserver/layers/CudnnBatchNormLayer.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "BatchNormBaseLayer.h" #include "Layer.h" #include "paddle/utils/Stat.h" @@ -46,12 +47,9 @@ public: void backward(const UpdateCallback& callback = nullptr) override; protected: - /** - * Epsilon value used in the batch normalization formula. - * Minimum allowed value is CUDNN_BN_MIN_EPSILON defined in cudnn.h. - * Same epsilon value should be used in forward and backward functions. - */ - static const double EPS; + /// Epsilon value used in the batch normalization formula. + /// Same epsilon value should be used in forward and backward functions. + double eps_; /// Input/output tensor descriptor desc hl_tensor_descriptor ioDesc_; diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp index d66c361ae05e4a1089786e4620d2eb2218a8a29c..7faca0f8b7f54fa0a09e8fdab11064c8c26df375 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -21,8 +21,6 @@ namespace paddle { REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer); -const real MKLDNNBatchNormLayer::EPS = 1E-5; - bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { if (!MKLDNNLayer::init(layerMap, parameterMap)) { @@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap, useGlobalStats_ = config_.use_global_stats(); } movingAvgFraction_ = config_.moving_average_fraction(); + epsilon_ = config_.epsilon(); + VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use") << " --- global stats"; VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_; @@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD( if (wgt) { flags_ = (flags_ | batch_normalization_flag::use_scale_shift); } - auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_); + auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), epsilon_, flags_); pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_)); CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc()); if (wgt) { @@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD( } CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc()); auto md = in->getMemoryDesc(); - auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_); + auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, epsilon_, flags_); pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_)); CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc()); CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc()); diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.h b/paddle/gserver/layers/MKLDNNBatchNormLayer.h index 387c58f02298b0441cc3bbbc4879eed6d892164c..1cf33cb34fa9cd7c9b8487a0a4a0011fb129e311 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.h +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.h @@ -32,7 +32,8 @@ protected: std::shared_ptr fwdPD_; // Epsilon value used in the batch normalization formula. - static const real EPS; + real epsilon_; + // weight and bias in paddle std::unique_ptr weight_; std::unique_ptr biases_; diff --git a/paddle/operators/ftrl_op.cc b/paddle/operators/ftrl_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb7ae6919623f10a6c4ec98c0e942c1590ac9a7a --- /dev/null +++ b/paddle/operators/ftrl_op.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/ftrl_op.h" + +namespace paddle { +namespace operators { + +class FTRLOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("SquaredAccumulator"), + "Input(SquaredAccumulator) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LinearAccumulator"), + "Input(LinearAccumulator) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of FTRL should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("SquaredAccumOut"), + "Output(SquaredAccumOut) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LinearAccumOut"), + "Output(LinearAccumOut) of FTRL should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), + "Two input of FTRL Op's dimension must be same."); + + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("SquaredAccumOut", param_dim); + ctx->SetOutputDim("LinearAccumOut", param_dim); + } +}; + +class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated."); + AddInput("SquaredAccumulator", + "(Tensor, default Tensor) " + "Accumulator that accumulates squared gradients."); + AddInput("LinearAccumulator", + "(Tensor, default Tensor) " + "Accumulator that accumulates linear gradients."); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter."); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1."); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value."); + AddOutput("SquaredAccumOut", + "(Tensor) Output accumulated squared" + " gradients."); + AddOutput("LinearAccumOut", + "(Tensor) Output accumulated linear" + " gradients."); + + AddAttr("l1", + "(float, default 0.0) " + "L1 regularization strength.") + .SetDefault(0.0f); + AddAttr("l2", + "(float, default 0.0) " + "L2 regularization strength.") + .SetDefault(0.0f); + AddAttr("lr_power", + "(float, default -0.5f) " + "Learning Rate Power.") + .SetDefault(-0.5f); + AddComment(R"DOC( +FTRL (Follow The Regularized Leader) Operator. + +Optimizer that implements the FTRL algorithm: + +$$ +new\_accum = squared\_accum + grad^2 \\ +if (lr\_power == -0.5) { + linear\_accum += grad - (\surd(new\_accum) - \surd(squared\_accum)) / + (learning\_rate * param) \\ +} else { + linear\_accum += grad - + (new\_accum^{-lr\_power} - accum^{-lr\_power}) / + (learning\_rate * param) \\ +} + +x = (l1 * sign(linear\_accum) - linear\_accum) +if (lr\_power == -0.5) { + y = \frac{\surd(new\_accum)}{learning\_rate} + (2 * l2) \\ + pre\_shrink = \frac{x}{y} \\ + param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\ +} else { + y = \frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2) \\ + pre\_shrink = \frac{x}{y} \\ + param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\ +} +squared\_accum += grad^2; +$$ + +The paper that proposed Follow The Regularized Leader (FTRL): +(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker); +REGISTER_OP_CPU_KERNEL(ftrl, + ops::FTRLOpKernel); diff --git a/paddle/operators/ftrl_op.cu b/paddle/operators/ftrl_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..97b36dade6f531df49615ae2d44d565eadba7154 --- /dev/null +++ b/paddle/operators/ftrl_op.cu @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/ftrl_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(ftrl, + ops::FTRLOpKernel); diff --git a/paddle/operators/ftrl_op.h b/paddle/operators/ftrl_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b040162f8d1d8998aa13021c10a25fe57135c1e9 --- /dev/null +++ b/paddle/operators/ftrl_op.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class FTRLOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param_out = ctx.Output("ParamOut"); + auto* sq_accum_out = ctx.Output("SquaredAccumOut"); + auto* lin_accum_out = ctx.Output("LinearAccumOut"); + + param_out->mutable_data(ctx.GetPlace()); + sq_accum_out->mutable_data(ctx.GetPlace()); + lin_accum_out->mutable_data(ctx.GetPlace()); + + auto grad = ctx.Input("Grad"); + + auto l1 = static_cast(ctx.Attr("l1")); + auto l2 = static_cast(ctx.Attr("l2")); + auto lr_power = static_cast(ctx.Attr("lr_power")); + + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto sq_accum = + EigenVector::Flatten(*ctx.Input("SquaredAccumulator")); + auto lin_accum = + EigenVector::Flatten(*ctx.Input("LinearAccumulator")); + auto g = EigenVector::Flatten(*grad); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + + auto p_out = EigenVector::Flatten(*param_out); + auto s_acc_out = EigenVector::Flatten(*sq_accum_out); + auto l_acc_out = EigenVector::Flatten(*lin_accum_out); + auto place = ctx.GetEigenDevice(); + + Eigen::DSizes grad_dsize(grad->numel()); + + auto new_accum = sq_accum + g * g; + // Special case for lr_power = -0.5 + if (lr_power == static_cast(-0.5)) { + l_acc_out.device(place) = + lin_accum + g - + ((new_accum.sqrt() - sq_accum.sqrt()) / lr.broadcast(grad_dsize)) * p; + } else { + l_acc_out.device(place) = + lin_accum + g - + ((new_accum.pow(-lr_power) - sq_accum.pow(-lr_power)) / + lr.broadcast(grad_dsize)) * + p; + } + + auto x = (l_acc_out.constant(l1) * l_acc_out.sign() - l_acc_out); + if (lr_power == static_cast(-0.5)) { + auto y = (new_accum.sqrt() / lr.broadcast(grad_dsize)) + + l_acc_out.constant(static_cast(2) * l2); + auto pre_shrink = x / y; + p_out.device(place) = + (l_acc_out.abs() > l_acc_out.constant(l1)) + .select(pre_shrink, p.constant(static_cast(0))); + } else { + auto y = (new_accum.pow(-lr_power) / lr.broadcast(grad_dsize)) + + l_acc_out.constant(static_cast(2) * l2); + auto pre_shrink = x / y; + p_out.device(place) = + (l_acc_out.abs() > l_acc_out.constant(l1)) + .select(pre_shrink, p.constant(static_cast(0))); + } + + s_acc_out.device(place) = sq_accum + g * g; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 2c2cc6245932d4af56a68d6399ce31f008bf3748..e2f5592248fd0b6166c2d11af02cef7815673def 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -540,6 +540,10 @@ message LayerConfig { // for switch order layer optional ReshapeConfig reshape_conf = 59; + + // for batch normalization layer + // The small constant added to the variance to improve numeric stability. + optional double epsilon = 60 [ default = 0.00001 ]; } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0941f10cf1ef337ac0e0225aea250dcdd8a27614..5ba0e50c6ba0f84a3ea87d5a5199fef23a5b05ea 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2412,6 +2412,7 @@ class BatchNormLayer(LayerBase): bias=True, img3D=False, use_global_stats=True, + epsilon=1e-5, moving_average_fraction=0.9, batch_norm_type=None, mean_var_names=None, @@ -2460,6 +2461,9 @@ class BatchNormLayer(LayerBase): self.config.use_global_stats = use_global_stats if moving_average_fraction is not None: self.config.moving_average_fraction = moving_average_fraction + if epsilon is not None: + assert epsilon >= 1e-5, "epsilon must be no less than 1e-5." + self.config.epsilon = epsilon input_layer = self.get_input_layer(0) image_conf = self.config.inputs[0].image_conf diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 6bd5ce4fe2f70946befb388986dff603bdae0b8e..8e127c9489ca5a4ed190e6d4e12ec4c9b28ad9cf 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -3118,6 +3118,7 @@ def batch_norm_layer(input, param_attr=None, layer_attr=None, batch_norm_type=None, + epsilon=1e-5, moving_average_fraction=0.9, use_global_stats=None, mean_var_names=None): @@ -3188,6 +3189,8 @@ def batch_norm_layer(input, will use the mean and variance of the current batch of test data. :type use_global_stats: bool | None. + :param epsilon: The small constant added to the variance to improve numeric stability. + :type epsilon: float. :param moving_average_fraction: Factor used in the moving average computation. :math:`runningMean = newMean*(1-factor) + runningMean*factor` :type moving_average_fraction: float. @@ -3205,6 +3208,7 @@ def batch_norm_layer(input, assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \ (batch_norm_type == "mkldnn_batch_norm") or \ (batch_norm_type == "cudnn_batch_norm") + l = Layer( name=name, img3D=img3D, @@ -3214,6 +3218,7 @@ def batch_norm_layer(input, type=LayerType.BATCH_NORM_LAYER, batch_norm_type=batch_norm_type, bias=ParamAttr.to_bias(bias_attr), + epsilon=epsilon, moving_average_fraction=moving_average_fraction, use_global_stats=use_global_stats, mean_var_names=mean_var_names, diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr index b14121e82cb7d9516c4771fc896b9b3b9e01d1c8..3e0f957648879d4350d662b336c953273bac1378 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr @@ -65,6 +65,7 @@ layers { height: 227 width: 227 depth: 1 + epsilon: 1e-05 } layers { name: "__crmnorm_0__" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr index c7a487a11231cba6182b654108773037bdb0ec35..a18a4652e14c0cfc4dbca87e67d31aa663ee756b 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr @@ -65,6 +65,7 @@ layers { height: 256 width: 256 depth: 1 + epsilon: 1e-05 } layers { name: "__crmnorm_0__" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr index 832ed24a31dd2bedba9a4fce77d7a088d1796fdb..9b69ae4a3b3cbcc7c0c69a2d5b3728e2f0204f33 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr @@ -36,6 +36,7 @@ layers { height: 6 width: 20 depth: 3 + epsilon: 1e-05 } parameters { name: "___batch_norm_0__.w0" diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 4edc96437f8490012cd58526d8f8b23983074048..33a0829ba8d635ebd68b50f3da07da958fb79dcb 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -91,14 +91,14 @@ def set_omp_mkl_env_vars(trainer_count): .read()) return num_sockets * num_cores_per_socket else: - cmds = {"Darwin": "sysctl hw.physicalcpu"} + cmds = {"Darwin": "sysctl -n hw.physicalcpu"} return int(os.popen(cmds.get(platform.system(), "expr 1")).read()) def num_logical_processors(): '''Get the number of logical processors''' cmds = { "Linux": "grep \"processor\" /proc/cpuinfo|sort -u|wc -l", - "Darwin": "sysctl hw.logicalcpu" + "Darwin": "sysctl -n hw.logicalcpu" } return int(os.popen(cmds.get(platform.system(), "expr 1")).read()) diff --git a/python/paddle/v2/fluid/tests/test_ftrl_op.py b/python/paddle/v2/fluid/tests/test_ftrl_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f77ac4659a9b877829f7ae52dd005d9dd11dac07 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_ftrl_op.py @@ -0,0 +1,62 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFTRLOp(OpTest): + def setUp(self): + self.op_type = "ftrl" + w = np.random.random((102, 105)).astype("float32") + g = np.random.random((102, 105)).astype("float32") + sq_accum = np.full((102, 105), 0.1).astype("float32") + linear_accum = np.full((102, 105), 0.1).astype("float32") + lr = np.array([0.01]).astype("float32") + l1 = 0.1 + l2 = 0.2 + lr_power = -0.5 + + self.inputs = { + 'Param': w, + 'SquaredAccumulator': sq_accum, + 'LinearAccumulator': linear_accum, + 'Grad': g, + 'LearningRate': lr + } + self.attrs = { + 'l1': l1, + 'l2': l2, + 'lr_power': lr_power, + 'learning_rate': lr + } + new_accum = sq_accum + g * g + if lr_power == -0.5: + linear_out = linear_accum + g - ( + (np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * w + else: + linear_out = linear_accum + g - ((np.power( + new_accum, -lr_power) - np.power(sq_accum, -lr_power)) / lr) * w + + x = (l1 * np.sign(linear_out) - linear_out) + if lr_power == -0.5: + y = (np.sqrt(new_accum) / lr) + (2 * l2) + pre_shrink = x / y + param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0) + else: + y = (np.power(new_accum, -lr_power) / lr) + (2 * l2) + pre_shrink = x / y + param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0) + + sq_accum_out = sq_accum + g * g + + self.outputs = { + 'ParamOut': param_out, + 'SquaredAccumOut': sq_accum_out, + 'LinearAccumOut': linear_out + } + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()