diff --git a/benchmark/IntelOptimizedPaddle.md b/benchmark/IntelOptimizedPaddle.md index ab0be77324450521fee02b7bd7ea12fb9eacf86a..16c2390fd31bf1c79f29735fb98180d3f7302eb2 100644 --- a/benchmark/IntelOptimizedPaddle.md +++ b/benchmark/IntelOptimizedPaddle.md @@ -53,6 +53,15 @@ TBD - GoogLeNet +| BatchSize | 64 | 128 | 256 | +|--------------|-------| ------| -------| +| OpenBLAS | 89.52 | 96.97 | 108.25 | +| MKLML | 128.46| 137.89| 158.63 | +| MKL-DNN | 250.46| 264.83| 269.50 | + +chart on batch size 128 +TBD + ### Laptop TBD ### Desktop diff --git a/doc/design/reader/README.md b/doc/design/reader/README.md index 320dccec3ddc7bfe6042f4e65b2518ea7b1ad24a..2cd4b6225b61cf374458e40afabad7745f61ba71 100644 --- a/doc/design/reader/README.md +++ b/doc/design/reader/README.md @@ -1,25 +1,25 @@ # Python Data Reader Design Doc -At training and testing time, PaddlePaddle programs need to read data. To ease the users' work to write data reading code, we define that +During the training and testing phases, PaddlePaddle programs need to read data. To help users write code that reads input data, we define the following: -- A *reader* is a function that reads data (from file, network, random number generator, etc) and yields data items. -- A *reader creator* is a function that returns a reader function. -- A *reader decorator* is a function, which accepts one or more readers, and returns a reader. -- A *batch reader* is a function that reads data (from *reader*, file, network, random number generator, etc) and yields a batch of data items. +- A *reader*: A function that reads data (from file, network, random number generator, etc.) and yields the data items. +- A *reader creator*: A function that returns a reader function. +- A *reader decorator*: A function that takes in one or more readers and returns a reader. +- A *batch reader*: A function that reads data (from *reader*, file, network, random number generator, etc.) and yields a batch of data items. 
-and provide function which converts reader to batch reader, frequently used reader creators and reader decorators. +We also provide a function that converts a reader to a batch reader, along with frequently used reader creators and reader decorators. ## Data Reader Interface -Indeed, *data reader* doesn't have to be a function that reads and yields data items. It can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`): +A *data reader* doesn't have to be a function that reads and yields data items. It can be any function without parameters that creates an iterable (anything that can be used in `for x in iterable`), as follows: ``` iterable = data_reader() ``` -Element produced from the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int) +The item produced from the iterable should be a **single** entry of data and **not** a mini batch. The entry of data could be a single item or a tuple of items. Each item should be one of the [supported types](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., a numpy 1d array of float32, an int, or a list of int). 
-An example implementation for single item data reader creator: +An example implementation of a single-item data reader creator is as follows: ```python def reader_creator_random_image(width, height): @@ -29,7 +29,7 @@ def reader_creator_random_image(width, height): return reader ``` -An example implementation for multiple item data reader creator: +An example implementation of a multiple-item data reader creator is as follows: ```python def reader_creator_random_image_and_label(width, height, label): def reader(): @@ -40,9 +40,10 @@ def reader_creator_random_image_and_label(width, height, label): ## Batch Reader Interface -*batch reader* can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`). The output of the iterable should be a batch (list) of data items. Each item inside the list must be a tuple. +A *batch reader* can be any function without parameters that creates an iterable (anything that can be used in `for x in iterable`). The output of the iterable should be a batch (list) of data items. Each item inside the list must be a tuple. + +Here are some valid outputs: -Here are valid outputs: ```python # a mini batch of three data items. Each data item consist three columns of data, each of which is 1. [(1, 1, 1), @@ -58,20 +59,22 @@ Here are valid outputs: Please note that each item inside the list must be a tuple, below is an invalid output: ```python # wrong, [1,1,1] needs to be inside a tuple: ([1,1,1],). - # Otherwise it's ambiguous whether [1,1,1] means a single column of data [1, 1, 1], - # or three column of datas, each of which is 1. + # Otherwise it is ambiguous whether [1,1,1] means a single column of data [1, 1, 1], + # or three columns of data, each of which is 1. 
[[1,1,1], [2,2,2], [3,3,3]] ``` -It's easy to convert from reader to batch reader: +It is easy to convert from a reader to a batch reader: + ```python mnist_train = paddle.dataset.mnist.train() mnist_train_batch_reader = paddle.batch(mnist_train, 128) ``` -Also easy to create custom batch reader: +It is also straightforward to create a custom batch reader: + ```python def custom_batch_reader(): while True: @@ -85,7 +88,8 @@ mnist_random_image_batch_reader = custom_batch_reader ## Usage -batch reader, mapping from item(s) read to data layer, batch size and number of total pass will be passed into `paddle.train`: +The following shows how to use the reader with PaddlePaddle: +The batch reader, a mapping from item(s) to data layer, the batch size and the total number of passes are passed into `paddle.train` as follows: ```python # two data layer is created: @@ -99,13 +103,13 @@ paddle.train(batch_reader, {"image":0, "label":1}, 128, 10, ...) ## Data Reader Decorator -*Data reader decorator* takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax. +A *data reader decorator* takes in one or more data readers and returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use the `@` syntax. -Since we have a strict interface for data readers (no parameter, return a single data item). Data reader can be used flexiable via data reader decorators. Following are a few examples: +Since we have a strict interface for data readers (a data reader takes no parameters and returns a single data item), a data reader can be used flexibly through data reader decorators. The following are a few examples: ### Prefetch Data -Since reading data may take time and training can not proceed without data. It is generally a good idea to prefetch data. 
+Since reading data may take some time and training cannot proceed without data, it is generally a good idea to prefetch the data. Use `paddle.reader.buffered` to prefetch data: @@ -117,9 +121,9 @@ buffered_reader = paddle.reader.buffered(paddle.dataset.mnist.train(), 100) ### Compose Multiple Data Readers -For example, we want to use a source of real images (reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). +For example, suppose we want to use a source of real images (say, reusing the mnist dataset) and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). -We can do: +We can do the following: ```python def reader_creator_random_image(width, height): @@ -139,13 +143,13 @@ false_reader = reader_creator_bool(False) reader = paddle.reader.compose(paddle.dataset.mnist.train(), data_reader_creator_random_image(20, 20), true_reader, false_reader) # Skipped 1 because paddle.dataset.mnist.train() produces two items per data entry. -# And we don't care second item at this time. +# And we don't care about the second item at this time. paddle.train(paddle.batch(reader, 128), {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...) ``` ### Shuffle -Given shuffle buffer size `n`, `paddle.reader.shuffle` will return a data reader that buffers `n` data entries and shuffle them before a data entry is read. +Given the shuffle buffer size `n`, `paddle.reader.shuffle` returns a data reader that buffers `n` data entries and shuffles them before a data entry is read. Example: ```python @@ -154,21 +158,21 @@ reader = paddle.reader.shuffle(paddle.dataset.mnist.train(), 512) ## Q & A -### Why reader return only a single entry, but not a mini batch? +### Why does a reader return only a single entry, and not a mini batch? 
-Always returning a single entry make reusing existing data readers much easier (e.g., if existing reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2). +Returning a single entry makes reusing existing data readers much easier (for example, if an existing reader returns 3 entries instead of a single entry, the training code will be more complicated because it needs to handle cases like a batch size of 2). -We provide function `paddle.batch` to turn (single entry) reader into batch reader. +We provide the function `paddle.batch` to turn a (single-entry) reader into a batch reader. -### Why do we need batch reader, isn't train take reader and batch_size as arguments sufficient? +### Why do we need a batch reader, isn't it sufficient to give the reader and batch_size as arguments during training? -In most of the case, train taking reader and batch_size as arguments would be sufficent. However sometimes user want to customize order of data entries inside a mini batch. Or even change batch size dynamically. +In most cases, it would be sufficient to give the reader and batch_size as arguments to the train method. However, sometimes the user wants to customize the order of data entries inside a mini batch, or even change the batch size dynamically. For these cases, a batch reader is helpful. -### Why use a dictionary but not a list to provide mapping? +### Why use a dictionary instead of a list to provide mapping? -We decided to use dictionary (`{"image":0, "label":1}`) instead of list (`["image", "label"]`) is because that user can easily resue item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip item (e.g., using `{"image_a":0, "label":2}`). 
+Using a dictionary (`{"image":0, "label":1}`) instead of a list (`["image", "label"]`) gives the advantage that the user can easily reuse the items (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or even skip an item (e.g., using `{"image_a":0, "label":2}`). -### How to create custom data reader creator +### How to create a custom data reader creator ? ```python def image_reader_creator(image_path, label_path, n): @@ -192,7 +196,7 @@ paddle.train(paddle.batch(reader, 128), {"image":0, "label":1}, ...) ### How is `paddle.train` implemented -An example implementation of paddle.train could be: +An example implementation of paddle.train is: ```python def train(batch_reader, mapping, batch_size, total_pass): diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index fac5f1d0e25fe205f89fc7eeb9fadfd8431517d5..09bff0a68db82aa723dc08aa83c775910e17c5b8 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -38,7 +38,7 @@ inline bool IsExpand(std::vector& filter_dim, std::vector& dilations) { bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; for (size_t j = 0; j < strides.size(); ++j) { - filter_1 = filter_1 && (static_cast(filter_dim[j]) == 1); + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); strides_1 = strides_1 && (strides[j] == 1); padding_0 = padding_0 && (paddings[j] == 0); dilation_1 = dilation_1 && (dilations[j] == 1); @@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); - // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} std::vector filter_shape_vec(framework::vectorize(filter.dims())); - filter_shape_vec.erase(filter_shape_vec.begin(), - filter_shape_vec.begin() + 2); - - // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w} + // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w} std::vector 
output_shape_vec(framework::vectorize(output->dims())); - output_shape_vec.erase(output_shape_vec.begin(), - output_shape_vec.begin() + 2); // use col_shape in the im2col calculation // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, // o_h, o_w} - std::vector col_shape_vec; - col_shape_vec.push_back(input->dims()[1] / groups); - col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), - filter_shape_vec.end()); - col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(), - output_shape_vec.end()); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } framework::DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * // o_h * o_w) framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + framework::flatten_to_2d(col_shape, data_dim + 1); bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; @@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - } else if (filter_shape_vec.size() == 2) { + } else if (data_dim == 2U) { // im2col im2col(context.device_context(), in_slice, dilations, strides, std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &col); - } else if (filter_shape_vec.size() == 3) { + } else if (data_dim == 3U) { // vol2col vol2col(context.device_context(), in_slice, dilations, strides, paddings, &col); @@ -206,26 +202,22 @@ class GemmConvGradKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); - // 
filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} std::vector filter_shape_vec(framework::vectorize(filter.dims())); - filter_shape_vec.erase(filter_shape_vec.begin(), - filter_shape_vec.begin() + 2); - - // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w} + // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w} std::vector output_shape_vec( framework::vectorize(output_grad->dims())); - output_shape_vec.erase(output_shape_vec.begin(), - output_shape_vec.begin() + 2); // use col_shape in the im2col calculation // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, // o_h, o_w} - std::vector col_shape_vec; - col_shape_vec.push_back(input->dims()[1] / groups); - col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), - filter_shape_vec.end()); - col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(), - output_shape_vec.end()); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } framework::DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation @@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel { // or // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w) framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + framework::flatten_to_2d(col_shape, data_dim + 1); framework::DDim input_shape = framework::slice_ddim( input->dims(), 1, static_cast(input->dims().size())); @@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel { out_grad_slice, false, T(1.0), &col_matrix, T(0.0)); - if (is_expand && filter_shape_vec.size() == 2) { + if (is_expand && data_dim == 2U) { 
col2im(context.device_context(), col, dilations, strides, std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &in_grad_slice); - } else if (is_expand && filter_shape_vec.size() == 3) { + } else if (is_expand && data_dim == 3U) { col2vol(context.device_context(), col, dilations, strides, paddings, &in_grad_slice); } @@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - } else if (filter_shape_vec.size() == 2) { + } else if (data_dim == 2U) { im2col(context.device_context(), in_slice, dilations, strides, std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &col); - } else if (filter_shape_vec.size() == 3) { + } else if (data_dim == 3U) { vol2col(context.device_context(), in_slice, dilations, strides, paddings, &col); } diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index ab336ad23ce1c180b68d04e4c85b299e301d5376..0fc0735788c499c2d520c0cc689e1ce07ba67ce8 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -68,30 +68,26 @@ class GemmConvTransposeKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); - // input_shape_vec: {h, w} or {d, h, w} + // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} std::vector input_shape_vec = framework::vectorize(input->dims()); - input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2); - - // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w} std::vector filter_shape_vec = framework::vectorize(filter.dims()); - filter_shape_vec.erase(filter_shape_vec.begin(), - filter_shape_vec.begin() + 2); // use col_shape in the im2col and col2im (or vol2col and col2vol) // calculation // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} - std::vector col_shape_vec; - 
col_shape_vec.push_back(output->dims()[1]); - col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), - filter_shape_vec.end()); - col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(), - input_shape_vec.end()); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = output->dims()[1]; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; + } DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) - DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -136,7 +132,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { input_batch, false, static_cast(1.0), &col_matrix, static_cast(0.0)); - if (filter_shape_vec.size() == 2) { + if (data_dim == 2U) { // col2im: col_matrix -> dy // from (c * k_h * k_w, h * w) to (c, o_h, o_w) col2im(context.device_context(), col, @@ -144,7 +140,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &output_batch); - } else if (filter_shape_vec.size() == 3) { + } else if (data_dim == 3U) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) col2vol(context.device_context(), col, dilations, strides, paddings, @@ -176,30 +172,26 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); - // input_shape_vec: {h, w} or {d, h, w} + // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} std::vector input_shape_vec = framework::vectorize(input->dims()); - 
input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2); - - // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w} std::vector filter_shape_vec = framework::vectorize(filter.dims()); - filter_shape_vec.erase(filter_shape_vec.begin(), - filter_shape_vec.begin() + 2); // use col_shape in the im2col and col2im (or vol2col and col2vol) // calculation // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} - std::vector col_shape_vec; - col_shape_vec.push_back(output_grad->dims()[1]); - col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), - filter_shape_vec.end()); - col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(), - input_shape_vec.end()); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = output_grad->dims()[1]; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; + } DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) - DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) DDim output_shape = framework::slice_ddim(output_grad->dims(), 1, @@ -248,7 +240,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { Tensor output_grad_batch = output_grad->Slice(i, i + 1).Resize(output_shape); - if (filter_shape_vec.size() == 2) { + if (data_dim == 2U) { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) im2col(context.device_context(), output_grad_batch, @@ -256,7 +248,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { std::vector{paddings[0], 
paddings[1], paddings[0], paddings[1]}, &col); - } else if (filter_shape_vec.size() == 3) { + } else if (data_dim == 3U) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) vol2col(context.device_context(), output_grad_batch, dilations, diff --git a/paddle/operators/ftrl_op.cc b/paddle/operators/ftrl_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb7ae6919623f10a6c4ec98c0e942c1590ac9a7a --- /dev/null +++ b/paddle/operators/ftrl_op.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/ftrl_op.h" + +namespace paddle { +namespace operators { + +class FTRLOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("SquaredAccumulator"), + "Input(SquaredAccumulator) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LinearAccumulator"), + "Input(LinearAccumulator) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of FTRL should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("SquaredAccumOut"), + "Output(SquaredAccumOut) of FTRL should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("LinearAccumOut"), + "Output(LinearAccumOut) of FTRL should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), + "Two input of FTRL Op's dimension must be same."); + + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("SquaredAccumOut", param_dim); + ctx->SetOutputDim("LinearAccumOut", param_dim); + } +}; + +class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated."); + AddInput("SquaredAccumulator", + "(Tensor, default Tensor) " + "Accumulator that accumulates squared gradients."); + 
AddInput("LinearAccumulator", + "(Tensor, default Tensor) " + "Accumulator that accumulates linear gradients."); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter."); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1."); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value."); + AddOutput("SquaredAccumOut", + "(Tensor) Output accumulated squared" + " gradients."); + AddOutput("LinearAccumOut", + "(Tensor) Output accumulated linear" + " gradients."); + + AddAttr("l1", + "(float, default 0.0) " + "L1 regularization strength.") + .SetDefault(0.0f); + AddAttr("l2", + "(float, default 0.0) " + "L2 regularization strength.") + .SetDefault(0.0f); + AddAttr("lr_power", + "(float, default -0.5f) " + "Learning Rate Power.") + .SetDefault(-0.5f); + AddComment(R"DOC( +FTRL (Follow The Regularized Leader) Operator. + +Optimizer that implements the FTRL algorithm: + +$$ +new\_accum = squared\_accum + grad^2 \\ +if (lr\_power == -0.5) { + linear\_accum += grad - ((\surd(new\_accum) - \surd(squared\_accum)) / + learning\_rate) * param \\ +} else { + linear\_accum += grad - + ((new\_accum^{-lr\_power} - squared\_accum^{-lr\_power}) / + learning\_rate) * param \\ +} + +x = (l1 * sign(linear\_accum) - linear\_accum) \\ +if (lr\_power == -0.5) { + y = \frac{\surd(new\_accum)}{learning\_rate} + (2 * l2) \\ + pre\_shrink = \frac{x}{y} \\ + param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\ +} else { + y = \frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2) \\ + pre\_shrink = \frac{x}{y} \\ + param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\ +} +squared\_accum += grad^2; +$$ + +The paper that proposed Follow The Regularized Leader (FTRL): +(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, 
ops::FTRLOpMaker); +REGISTER_OP_CPU_KERNEL(ftrl, + ops::FTRLOpKernel); diff --git a/paddle/operators/ftrl_op.cu b/paddle/operators/ftrl_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..97b36dade6f531df49615ae2d44d565eadba7154 --- /dev/null +++ b/paddle/operators/ftrl_op.cu @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/ftrl_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(ftrl, + ops::FTRLOpKernel); diff --git a/paddle/operators/ftrl_op.h b/paddle/operators/ftrl_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b040162f8d1d8998aa13021c10a25fe57135c1e9 --- /dev/null +++ b/paddle/operators/ftrl_op.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class FTRLOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param_out = ctx.Output("ParamOut"); + auto* sq_accum_out = ctx.Output("SquaredAccumOut"); + auto* lin_accum_out = ctx.Output("LinearAccumOut"); + + param_out->mutable_data(ctx.GetPlace()); + sq_accum_out->mutable_data(ctx.GetPlace()); + lin_accum_out->mutable_data(ctx.GetPlace()); + + auto grad = ctx.Input("Grad"); + + auto l1 = static_cast(ctx.Attr("l1")); + auto l2 = static_cast(ctx.Attr("l2")); + auto lr_power = static_cast(ctx.Attr("lr_power")); + + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto sq_accum = + EigenVector::Flatten(*ctx.Input("SquaredAccumulator")); + auto lin_accum = + EigenVector::Flatten(*ctx.Input("LinearAccumulator")); + auto g = EigenVector::Flatten(*grad); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + + auto p_out = EigenVector::Flatten(*param_out); + auto s_acc_out = EigenVector::Flatten(*sq_accum_out); + auto l_acc_out = EigenVector::Flatten(*lin_accum_out); + auto place = ctx.GetEigenDevice(); + + Eigen::DSizes grad_dsize(grad->numel()); + + auto new_accum = sq_accum + g * g; + // Special case for lr_power = -0.5 + if (lr_power == static_cast(-0.5)) { + l_acc_out.device(place) = + lin_accum + g - + ((new_accum.sqrt() - sq_accum.sqrt()) / lr.broadcast(grad_dsize)) * p; + } else { + l_acc_out.device(place) = + lin_accum + g - + ((new_accum.pow(-lr_power) - sq_accum.pow(-lr_power)) / + lr.broadcast(grad_dsize)) * + p; + } + + auto x = (l_acc_out.constant(l1) * l_acc_out.sign() - l_acc_out); + if (lr_power == static_cast(-0.5)) { + auto y = (new_accum.sqrt() / lr.broadcast(grad_dsize)) + + l_acc_out.constant(static_cast(2) * l2); 
+ auto pre_shrink = x / y; + p_out.device(place) = + (l_acc_out.abs() > l_acc_out.constant(l1)) + .select(pre_shrink, p.constant(static_cast(0))); + } else { + auto y = (new_accum.pow(-lr_power) / lr.broadcast(grad_dsize)) + + l_acc_out.constant(static_cast(2) * l2); + auto pre_shrink = x / y; + p_out.device(place) = + (l_acc_out.abs() > l_acc_out.constant(l1)) + .select(pre_shrink, p.constant(static_cast(0))); + } + + s_acc_out.device(place) = sq_accum + g * g; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/huber_loss_op.cc b/paddle/operators/huber_loss_op.cc index 3435e74b0afb470fcbd1c0f4e06ad363352cac00..938803d5b36177c782fe40bc34fd92504e5bbf7b 100644 --- a/paddle/operators/huber_loss_op.cc +++ b/paddle/operators/huber_loss_op.cc @@ -70,11 +70,18 @@ input value and Y as the target value. Huber loss can evaluate the fitness of X to Y. Different from MSE loss, Huber loss is more robust for outliers. The shape of X and Y are [batch_size, 1]. The equation is: -L_{\delta}(y, f(x)) = +$$ +Out_{\delta}(X, Y)_i = \begin{cases} -0.5 * (y - f(x))^2, \quad |y - f(x)| \leq \delta \\ -\delta * (|y - f(x)| - 0.5 * \delta), \quad otherwise +0.5 * (Y_i - X_i)^2, +\quad |Y_i - X_i| \leq \delta \\ +\delta * (|Y_i - X_i| - 0.5 * \delta), +\quad otherwise \end{cases} +$$ + +In the above equation, $Out_\delta(X, Y)_i$, $X_i$ and $Y_i$ represent the ith +element of Out, X and Y. )DOC"); } diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 58356a4b7783241ca0292829bf05dc1a8ed80c6c..3018e50a4f54592123df6b9cadd45ce525d7b3e1 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -297,7 +297,25 @@ void set_constant_with_place( template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; -template struct ColwiseSum; +// template struct ColwiseSum; +// The ColwiseSum failed in debug mode, +// and only failed for this case. 
So reimplemented it. +template <> +void ColwiseSum::operator()( + const platform::DeviceContext& context, const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector->numel(), size); + framework::Tensor one; + one.mutable_data({in_dims[0]}, context.GetPlace()); + SetConstant set; + set(context, &one, static_cast(1.0)); + gemv(context, true, static_cast(in_dims[0]), + static_cast(in_dims[1]), 1.0, + input.data(), one.data(), + 0.0, vector->data()); +} } // namespace math } // namespace operators diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc index 075196b47eeaf118a588b96532d87a05e4e600c6..514f2adef284c8877e2e74b943b4e6419c6ae721 100644 --- a/paddle/operators/math/selected_rows_functor.cc +++ b/paddle/operators/math/selected_rows_functor.cc @@ -145,6 +145,8 @@ struct SelectedRowsAddTo { template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; template struct SelectedRowsAddToTensor { @@ -175,6 +177,8 @@ struct SelectedRowsAddToTensor { template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu index 47fe3b44a50fee9f41ae807793187258159b9f29..c40649e55ef93dec852ff6949b5cb134495e4ebf 100644 --- a/paddle/operators/math/selected_rows_functor.cu +++ b/paddle/operators/math/selected_rows_functor.cu @@ -173,6 +173,8 @@ struct SelectedRowsAddTo { template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; namespace { template @@ -223,6 +225,8 @@ struct SelectedRowsAddToTensor { template struct 
SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; } // namespace math } // namespace operators diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index c2b7632b2865a3ef66051d815d7722a08c6a8cbd..ddc210c26e69566fef9baa20f49ba1052e993b3f 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -176,4 +176,6 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker, ops::SumOpVarTypeInference); REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu index 5cf05b876b6d6a2ce61d9e10b7ec52ed3cef57d7..5c30dd4d470c2e0acecef18524a4a81f9eb786a9 100644 --- a/paddle/operators/sum_op.cu +++ b/paddle/operators/sum_op.cu @@ -14,4 +14,6 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h index a7d99cde106a0a66f122a8c43f49717c03e60dec..376bb0e6887c797c3c1019e92f738a62d01a9c51 100644 --- a/paddle/platform/cuda_helper.h +++ b/paddle/platform/cuda_helper.h @@ -31,6 +31,16 @@ constexpr int PADDLE_CUDA_NUM_THREADS = 512; // For atomicAdd. 
USE_CUDA_ATOMIC(Add, float); +USE_CUDA_ATOMIC(Add, int); +USE_CUDA_ATOMIC(Add, unsigned int); +USE_CUDA_ATOMIC(Add, unsigned long long int); + +CUDA_ATOMIC_WRAPPER(Add, int64_t) { + static_assert(sizeof(int64_t) == sizeof(long long int), + "long long should be int64"); + return CudaAtomicAdd(reinterpret_cast(address), + static_cast(val)); +} #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 USE_CUDA_ATOMIC(Add, double); diff --git a/python/paddle/v2/fluid/tests/test_ftrl_op.py b/python/paddle/v2/fluid/tests/test_ftrl_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f77ac4659a9b877829f7ae52dd005d9dd11dac07 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_ftrl_op.py @@ -0,0 +1,62 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFTRLOp(OpTest): + def setUp(self): + self.op_type = "ftrl" + w = np.random.random((102, 105)).astype("float32") + g = np.random.random((102, 105)).astype("float32") + sq_accum = np.full((102, 105), 0.1).astype("float32") + linear_accum = np.full((102, 105), 0.1).astype("float32") + lr = np.array([0.01]).astype("float32") + l1 = 0.1 + l2 = 0.2 + lr_power = -0.5 + + self.inputs = { + 'Param': w, + 'SquaredAccumulator': sq_accum, + 'LinearAccumulator': linear_accum, + 'Grad': g, + 'LearningRate': lr + } + self.attrs = { + 'l1': l1, + 'l2': l2, + 'lr_power': lr_power, + 'learning_rate': lr + } + new_accum = sq_accum + g * g + if lr_power == -0.5: + linear_out = linear_accum + g - ( + (np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * w + else: + linear_out = linear_accum + g - ((np.power( + new_accum, -lr_power) - np.power(sq_accum, -lr_power)) / lr) * w + + x = (l1 * np.sign(linear_out) - linear_out) + if lr_power == -0.5: + y = (np.sqrt(new_accum) / lr) + (2 * l2) + pre_shrink = x / y + param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0) + else: + y = (np.power(new_accum, -lr_power) / lr) + (2 * l2) + pre_shrink = x / y + param_out = np.where(np.abs(linear_out) > l1, 
pre_shrink, 0.0) + + sq_accum_out = sq_accum + g * g + + self.outputs = { + 'ParamOut': param_out, + 'SquaredAccumOut': sq_accum_out, + 'LinearAccumOut': linear_out + } + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index d3dc45742d92dc61b81d9cdc04056c5d5bdc2b63..f88e0b4e15f7115be21ef136cbb96ce96af9d99e 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -21,7 +21,7 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(avg_cost) program.append_backward(avg_cost) - # print str(program) + print str(program) def test_recognize_digits_mlp(self): program = Program() @@ -50,7 +50,8 @@ class TestBook(unittest.TestCase): input=predict, label=label, main_program=program) avg_cost = layers.mean(x=cost, main_program=program) self.assertIsNotNone(avg_cost) - # print str(program) + + print str(program) def test_simple_conv2d(self): program = Program() @@ -65,7 +66,7 @@ class TestBook(unittest.TestCase): filter_size=[4, 4], main_program=program) - # print str(program) + print str(program) def test_recognize_digits_conv(self): program = Program() @@ -104,7 +105,7 @@ class TestBook(unittest.TestCase): program.append_backward(avg_cost) - # print str(program) + print str(program) def test_word_embedding(self): program = Program() @@ -165,7 +166,7 @@ class TestBook(unittest.TestCase): avg_cost = layers.mean(x=cost, main_program=program) self.assertIsNotNone(avg_cost) - # print str(program) + print str(program) def test_linear_chain_crf(self): program = Program() @@ -182,7 +183,7 @@ class TestBook(unittest.TestCase): crf = layers.linear_chain_crf( input=hidden, label=label, main_program=program) - # print str(program) + print str(program) if __name__ == '__main__':