diff --git a/benchmark/paddle/image/vgg.py b/benchmark/paddle/image/vgg.py
index b8429975f5c83df6996e71478fe276b246e8b77b..420884ed8e1ae36a3f1772bfbe8323f3d0ea71e6 100644
--- a/benchmark/paddle/image/vgg.py
+++ b/benchmark/paddle/image/vgg.py
@@ -13,7 +13,7 @@ define_py_data_sources2(
settings(
batch_size=batch_size,
- learning_rate=0.01 / batch_size,
+ learning_rate=0.001 / batch_size,
learning_method=MomentumOptimizer(0.9),
regularization=L2Regularization(0.0005 * batch_size))
diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index 9686df00219001769d074ee815d9cc8db0258496..5a06825beb73e85d8a55b7b578b187bee2c4340c 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -46,16 +46,20 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
MESSAGE(STATUS "Build MKLDNN with ${MKLDNN_MKLROOT}")
ENDIF()
+SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} -Wno-error=strict-overflow")
+SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} -Wno-error=strict-overflow")
ExternalProject_Add(
${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
- GIT_TAG "v0.10"
+ GIT_TAG "v0.11"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DMKLROOT=${MKLDNN_MKLROOT}
+ CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
+ CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
-DMKLROOT:PATH=${MKLDNN_MKLROOT}
)
diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake
index 74f3279831357c21038df133df0f5a432a6dfd20..20dbc32a738d982df2d3f035206279c82c8de264 100644
--- a/cmake/external/mklml.cmake
+++ b/cmake/external/mklml.cmake
@@ -27,8 +27,8 @@ ENDIF()
INCLUDE(ExternalProject)
SET(MKLML_PROJECT "extern_mklml")
-SET(MKLML_VER "mklml_lnx_2018.0.20170720")
-SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.10/${MKLML_VER}.tgz")
+SET(MKLML_VER "mklml_lnx_2018.0.1.20171007")
+SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.11/${MKLML_VER}.tgz")
SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
SET(MKLML_DST_DIR "mklml")
diff --git a/doc/api/v2/data.rst b/doc/api/v2/data.rst
index fef87c4fbdb452771ecdb361c6eeae5b32bcee14..b56c7332cc284649c7e04328e51a7faa78593a39 100644
--- a/doc/api/v2/data.rst
+++ b/doc/api/v2/data.rst
@@ -2,112 +2,9 @@
Data Reader Interface and DataSets
==================================
+.. toctree::
+ :maxdepth: 1
-DataTypes
-=========
-
-.. automodule:: paddle.v2.data_type
- :members:
- :noindex:
-
-DataFeeder
-==========
-
-.. automodule:: paddle.v2.data_feeder
- :members:
- :noindex:
-
-Reader
-======
-
-.. automodule:: paddle.v2.reader
- :members:
- :noindex:
-
-.. automodule:: paddle.v2.reader.creator
- :members:
- :noindex:
-
-minibatch
-=========
-
-.. automodule:: paddle.v2.minibatch
- :members:
- :noindex:
-
-Dataset
-=======
-
-.. automodule:: paddle.v2.dataset
- :members:
- :noindex:
-
-mnist
-+++++
-
-.. automodule:: paddle.v2.dataset.mnist
- :members:
- :noindex:
-
-cifar
-+++++
-
-.. automodule:: paddle.v2.dataset.cifar
- :members:
- :noindex:
-
-conll05
-+++++++
-
-.. automodule:: paddle.v2.dataset.conll05
- :members: get_dict,get_embedding,test
- :noindex:
-
-imdb
-++++
-
-.. automodule:: paddle.v2.dataset.imdb
- :members:
- :noindex:
-
-imikolov
-++++++++
-
-.. automodule:: paddle.v2.dataset.imikolov
- :members:
- :noindex:
-
-movielens
-+++++++++
-
-.. automodule:: paddle.v2.dataset.movielens
- :members:
- :noindex:
-
-.. autoclass:: paddle.v2.dataset.movielens.MovieInfo
- :noindex:
-
-.. autoclass:: paddle.v2.dataset.movielens.UserInfo
- :noindex:
-
-sentiment
-+++++++++
-
-.. automodule:: paddle.v2.dataset.sentiment
- :members:
- :noindex:
-
-uci_housing
-+++++++++++
-
-.. automodule:: paddle.v2.dataset.uci_housing
- :members:
- :noindex:
-
-wmt14
-+++++
-
-.. automodule:: paddle.v2.dataset.wmt14
- :members:
- :noindex:
-
+ data/data_reader.rst
+ data/image.rst
+ data/dataset.rst
diff --git a/doc/api/v2/data/data_reader.rst b/doc/api/v2/data/data_reader.rst
new file mode 100644
index 0000000000000000000000000000000000000000..2ccfec9c284877a7576e9751526b169a4ac78d8e
--- /dev/null
+++ b/doc/api/v2/data/data_reader.rst
@@ -0,0 +1,36 @@
+=====================
+Data Reader Interface
+=====================
+
+
+DataTypes
+=========
+
+.. automodule:: paddle.v2.data_type
+ :members:
+ :noindex:
+
+DataFeeder
+==========
+
+.. automodule:: paddle.v2.data_feeder
+ :members:
+ :noindex:
+
+Reader
+======
+
+.. automodule:: paddle.v2.reader
+ :members:
+ :noindex:
+
+.. automodule:: paddle.v2.reader.creator
+ :members:
+ :noindex:
+
+minibatch
+=========
+
+.. automodule:: paddle.v2.minibatch
+ :members:
+ :noindex:
diff --git a/doc/api/v2/data/dataset.rst b/doc/api/v2/data/dataset.rst
new file mode 100644
index 0000000000000000000000000000000000000000..6a8ecc5bb1d855e0ded3719943ab3adb810de365
--- /dev/null
+++ b/doc/api/v2/data/dataset.rst
@@ -0,0 +1,75 @@
+Dataset
+=======
+
+.. automodule:: paddle.v2.dataset
+ :members:
+ :noindex:
+
+mnist
++++++
+
+.. automodule:: paddle.v2.dataset.mnist
+ :members:
+ :noindex:
+
+cifar
++++++
+
+.. automodule:: paddle.v2.dataset.cifar
+ :members:
+ :noindex:
+
+conll05
++++++++
+
+.. automodule:: paddle.v2.dataset.conll05
+ :members: get_dict,get_embedding,test
+ :noindex:
+
+imdb
+++++
+
+.. automodule:: paddle.v2.dataset.imdb
+ :members:
+ :noindex:
+
+imikolov
+++++++++
+
+.. automodule:: paddle.v2.dataset.imikolov
+ :members:
+ :noindex:
+
+movielens
++++++++++
+
+.. automodule:: paddle.v2.dataset.movielens
+ :members:
+ :noindex:
+
+.. autoclass:: paddle.v2.dataset.movielens.MovieInfo
+ :noindex:
+
+.. autoclass:: paddle.v2.dataset.movielens.UserInfo
+ :noindex:
+
+sentiment
++++++++++
+
+.. automodule:: paddle.v2.dataset.sentiment
+ :members:
+ :noindex:
+
+uci_housing
++++++++++++
+
+.. automodule:: paddle.v2.dataset.uci_housing
+ :members:
+ :noindex:
+
+wmt14
++++++
+
+.. automodule:: paddle.v2.dataset.wmt14
+ :members:
+ :noindex:
diff --git a/doc/api/v2/data/image.rst b/doc/api/v2/data/image.rst
new file mode 100644
index 0000000000000000000000000000000000000000..97651ffa6be56cf3ecaca2caca38a353fa5c1f49
--- /dev/null
+++ b/doc/api/v2/data/image.rst
@@ -0,0 +1,5 @@
+Image Interface
+===============
+
+.. automodule:: paddle.v2.image
+ :members:
diff --git a/doc/design/ops/images/LOD-and-shape-changes-during-decoding.jpg b/doc/design/ops/images/LOD-and-shape-changes-during-decoding.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b0d90f7b9d8184b314b0ee4e521f53eb5f1b455
Binary files /dev/null and b/doc/design/ops/images/LOD-and-shape-changes-during-decoding.jpg differ
diff --git a/doc/design/ops/sequence_decoder.md b/doc/design/ops/sequence_decoder.md
new file mode 100644
index 0000000000000000000000000000000000000000..9007aae7a8355ed06c6720a921351f81b859c1fe
--- /dev/null
+++ b/doc/design/ops/sequence_decoder.md
@@ -0,0 +1,245 @@
+# Design: Sequence Decoder Generating LoDTensors
+In tasks such as machine translation and image to text,
+a [sequence decoder](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.md) is necessary to generate sequences.
+
+This documentation describes how to implement the sequence decoder as an operator.
+
+## Beam Search based Decoder
+The [beam search algorithm](https://en.wikipedia.org/wiki/Beam_search) is necessary when generating sequences,
+it is a heuristic search algorithm that explores the paths by expanding the most promising node in a limited set.
+
+In the old version of PaddlePaddle, a C++ class `RecurrentGradientMachine` implements the general sequence decoder based on beam search.
+Due to the complexity, the implementation relies on a lot of special data structures,
+which are quite cumbersome and hard for users to customize.
+
+There are a lot of heuristic tricks in the sequence generation tasks,
+so the flexibility of sequence decoder is very important to users.
+
+During PaddlePaddle's refactoring work,
+some new concepts are proposed, such as [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) and [TensorArray](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md), that can better support sequence usage,
+and they can help to make the implementation of the beam search based sequence decoder **more transparent and modular**.
+
+For example, the RNN states, candidate IDs and probabilities of beam search can be represented as `LoDTensors`;
+the selected candidates' IDs in each time step can be stored in a `TensorArray`, and `Packed` to the sentences translated.
+
+## Changing LoD's absolute offset to relative offsets
+The current `LoDTensor` is designed to store levels of variable-length sequences;
+it stores several arrays of integers, each of which represents a level.
+
+The integers in each level represent the begin and end (not inclusive) offsets of a sequence **in the underlying tensor**;
+let's call this format the **absolute-offset LoD** for clarity.
+
+The absolute-offset LoD can quickly retrieve any sequence but fails to represent empty sequences. For example, a two-level LoD is as follows
+```python
+[[0, 3, 9]
+ [0, 2, 3, 3, 3, 9]]
+```
+The first level tells that there are two sequences:
+- the first's offset is `[0, 3)`
+- the second's offset is `[3, 9)`
+
+while on the second level, there are several empty sequences that both begin and end at `3`.
+It is impossible to tell how many empty second-level sequences exist in the first-level sequences.
+
+There are many scenarios that rely on empty sequence representation,
+such as machine translation or image to text, where one instance has no translations or a prefix has an empty candidate set.
+
+So let's introduce another format of LoD,
+it stores **the offsets of the lower level sequences** and is called **relative-offset** LoD.
+
+For example, to represent the same sequences of the above data
+
+```python
+[[0, 3, 5]
+ [0, 2, 3, 3, 3, 9]]
+```
+
+The first level represents that there are two sequences;
+their offsets in the second-level LoD are `[0, 3)` and `[3, 5)`.
+
+The second level is the same as in the absolute-offset example because the lower level is a tensor.
+It is easy to find out that the second sequence in the first-level LoD has two empty sequences.
+
+The following demos are based on relative-offset LoD.
+
+## Usage in a simple machine translation model
+Let's start from a simple machine translation model that is simplified from [machine translation chapter](https://github.com/PaddlePaddle/book/tree/develop/08.machine_translation) to draw a simple blueprint of what a sequence decoder can do and how to use it.
+
+The model has an encoder that learns the semantic vector from a sequence,
+and a decoder which uses the sequence decoder to generate new sentences.
+
+**Encoder**
+```python
+import paddle as pd
+
+dict_size = 8000
+source_dict_size = dict_size
+target_dict_size = dict_size
+word_vector_dim = 128
+encoder_dim = 128
+decoder_dim = 128
+beam_size = 5
+max_length = 120
+
+# encoder
+src_word_id = pd.data(
+ name='source_language_word',
+ type=pd.data.integer_value_sequence(source_dict_size))
+src_embedding = pd.embedding(size=source_dict_size, size=word_vector_dim)
+
+src_word_vec = pd.lookup(src_embedding, src_word_id)
+
+encoder_out_seq = pd.gru(input=src_word_vec, size=encoder_dim)
+
+encoder_ctx = pd.last_seq(encoder_out_seq)
+# encoder_ctx_proj is the learned semantic vector
+encoder_ctx_proj = pd.fc(
+ encoder_ctx, size=decoder_dim, act=pd.activation.Tanh(), bias=None)
+```
+
+**Decoder**
+
+```python
+def generate():
+ decoder = pd.while_loop()
+ with decoder.step():
+ decoder_mem = decoder.memory(init=encoder_ctx) # mark the memory
+ generated_ids = decoder.memory() # TODO init to batch_size s
+ generated_scores = decoder.memory() # TODO init to batch_size 1s or 0s
+
+ target_word = pd.lookup(trg_embedding, generated_ids)
+ # expand encoder_ctx's batch to fit target_word's lod
+ # for example
+ # decoder_mem.lod is
+ # [[0 1 3],
+ # [0 1 3 6]]
+ # its tensor content is [a1 a2 a3 a4 a5]
+ # which means there are 2 sentences to translate
+ # - the first sentence has 1 translation prefixes, the offsets are [0, 1)
+ # - the second sentence has 2 translation prefixes, the offsets are [1, 3) and [3, 6)
+ # the target_word.lod is
+ # [[0, 1, 6]
+ # [0, 2, 4, 7, 9 12]]
+ # which means 2 sentences to translate, each has 1 and 5 prefixes
+ # the first prefix has 2 candidates
+ # the following has 2, 3, 2, 3 candidates
+ # the encoder_ctx_expanded's content will be
+ # [a1 a1 a2 a2 a3 a3 a3 a4 a4 a5 a5 a5]
+ encoder_ctx_expanded = pd.lod_expand(encoder_ctx, target_word)
+ decoder_input = pd.fc(
+ act=pd.activation.Linear(),
+ input=[target_word, encoder_ctx],
+ size=3 * decoder_dim)
+ gru_out, cur_mem = pd.gru_step(
+ decoder_input, mem=decoder_mem, size=decoder_dim)
+ scores = pd.fc(
+ gru_out,
+ size=trg_dic_size,
+ bias=None,
+ act=pd.activation.Softmax())
+ # K is a config
+ topk_scores, topk_ids = pd.top_k(scores, K)
+ topk_generated_scores = pd.add_scalar(topk_scores, generated_scores)
+
+ selected_ids, selected_generation_scores = decoder.beam_search(
+ topk_ids, topk_generated_scores)
+
+ # update the states
+ decoder_mem.update(cur_mem) # tells how to update state
+ generated_ids.update(selected_ids)
+ generated_scores.update(selected_generation_scores)
+
+ decoder.output(selected_ids)
+ decoder.output(selected_generation_scores)
+
+translation_ids, translation_scores = decoder()
+```
+The `decoder.beam_search` is an operator that, given the candidates and the scores of translations including the candidates,
+returns the result of the beam search algorithm.
+
+In this way, users can customize anything on the inputs or outputs of beam search, for example, three ways to prune some translation prefixes
+
+1. make the corresponding elements in `topk_generated_scores` zero or some small values, beam_search will discard this candidate.
+2. remove some specific candidate in `selected_ids`
+3. get the final `translation_ids`, remove the translation sequence in it.
+
+The implementation of sequence decoder can reuse the C++ class [RNNAlgorithm](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/paddle/operators/dynamic_recurrent_op.h#L30),
+so the python syntax is quite similar to a [RNN](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/doc/design/block.md#blocks-with-for-and-rnnop).
+
+Both of them are two-level `LoDTensors`
+
+- the first level represents `batch_size` of (source) sentences;
+- the second level represents the candidate ID sets for translation prefix.
+
+For example, there are 3 source sentences to translate, and they have 2, 3 and 1 candidates respectively.
+
+Unlike an RNN, in sequence decoder, the previous state and the current state have different LoD and shape,
+a `lod_expand` operator is used to expand the LoD of the previous state to fit the current state.
+
+For example, the previous state
+
+* LoD is `[0, 1, 3][0, 2, 5, 6]`
+* content of tensor is `a1 a2 b1 b2 b3 c1`
+
+the current state stored in `encoder_ctx_expanded`
+
+* LoD is `[0, 2, 7][0 3 5 8 9 11 11]`
+* the content is
+ - a1 a1 a1 (a1 has 3 candidates, so the state should be copied 3 times, once for each candidate)
+ - a2 a2
+ - b1 b1 b1
+ - b2
+ - b3 b3
+ - None (c1 has 0 candidates, so c1 is dropped)
+
+Benefiting from the relative-offset LoD, an empty candidate set can be represented naturally.
+
+The status in each time step can be stored in a `TensorArray`, and `Pack`ed to a final LoDTensor; the corresponding syntax is
+
+```python
+decoder.output(selected_ids)
+decoder.output(selected_generation_scores)
+```
+
+the `selected_ids` is the candidate ids for the prefixes,
+it will be `Packed` by `TensorArray` to a two-level `LoDTensor`,
+the first level represents the source sequences,
+the second level represents generated sequences.
+
+Packing the `selected_scores` will get a `LoDTensor` that stores the scores of each candidate of the translations.
+
+Packing the `selected_generation_scores` will get a `LoDTensor`, and each tail is the probability of the translation.
+
+## LoD and shape changes during decoding
+<p align="center">
+<img src="./images/LOD-and-shape-changes-during-decoding.jpg"/>
+</p>
+
+According to the image above, the only phase that changes the LoD is beam search.
+
+## Beam search design
+The beam search algorithm will be implemented as one method of the sequence decoder; it has 3 inputs
+
+1. `topk_ids`, top K candidate ids for each prefix.
+2. `topk_scores`, the corresponding scores for `topk_ids`
+3. `generated_scores`, the score of the prefixes.
+
+All of them are LoDTensors, so that the sequence affiliation is clear.
+Beam search will keep a beam for each prefix and select a smaller candidate set for each prefix.
+
+It will return three variables
+
+1. `selected_ids`, the final candidate beam search function selected for the next step.
+2. `selected_scores`, the scores for the candidates.
+3. `generated_scores`, the updated scores for each prefix (with the new candidates appended).
+
+## Introducing the LoD-based `Pack` and `Unpack` methods in `TensorArray`
+The `selected_ids`, `selected_scores` and `generated_scores` are LoDTensors,
+and they exist in each time step,
+so it is natural to store them in arrays.
+
+Currently, PaddlePaddle has a module called `TensorArray` which can store an array of tensors,
+so it is better to store the results of beam search in a `TensorArray`.
+
+The `Pack` and `UnPack` in `TensorArray` are used to package tensors in the array to a `LoDTensor` or split the `LoDTensor` to an array of tensors.
+It needs some extensions to support packing or unpacking an array of `LoDTensors`.
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 4e8d630c2634682ff63b38182108eadebb5c7ff9..d485cdf6109274377ad0057223bdd8401e964aa7 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -21,7 +21,7 @@
#include "paddle/framework/var_desc.h"
#include "paddle/operators/net_op.h"
-USE_OP(fill_constant);
+USE_NO_KERNEL_OP(fill_constant);
namespace paddle {
namespace framework {
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index c5ae7b185460c8b0d68ba38bb9db9bd3d3fb14ea..3ec88d7a72c3339bf5e7d0ca3957a3f608f039b7 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -34,6 +34,21 @@ inline DataType ToDataType(std::type_index type) {
}
}
+inline std::type_index ToTypeIndex(DataType type) {  // map a DataType enum value to the std::type_index of a C++ type
+ switch (type) {
+ case DataType::FP32:
+ return typeid(float);
+ case DataType::FP64:
+ return typeid(double);
+ case DataType::INT32:
+ return typeid(int);
+ case DataType::INT64:
+ return typeid(int64_t);
+ default:  // every DataType value not listed above is rejected
+ PADDLE_THROW("Not support type %d", type);
+ }
+}
+
template
inline void VisitDataType(DataType type, Visitor visitor) {
switch (type) {
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 10c785e04c4fa2192f9c95513009cf7d8c123868..53b899a23997b71e723a298ec360a4e018d89878 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -79,6 +79,13 @@ DDim make_ddim(const std::vector& dims) {
return result;
}
+DDim make_ddim(const std::vector& dims) {  // overload for int-valued dims (the lambda below takes int); NOTE(review): vector element types are not visible in this view
+ std::vector res(dims.size());  // one converted entry per input dimension
+ std::transform(dims.begin(), dims.end(), res.begin(),
+ [](int d) { return static_cast(d); });  // element-wise cast of each dimension
+ return make_ddim(res);  // forward the converted vector to another make_ddim overload
+}
+
/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes errors
class DynamicMutableIndexer : public boost::static_visitor {
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index aa773868ab4b68acbc46dfa2cd2569d8b8b7789d..4ca5e49566b7ec006eba80f3f9808bacb1ff2615 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -81,6 +81,8 @@ struct DDim {
*/
DDim make_ddim(const std::vector& dims);
+DDim make_ddim(const std::vector& dims);
+
/**
* \brief Make a DDim from an initializer list
*
diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index 4fd72d64a90ae6f16dd1499ceb7fba6e40fe4cea..9b2779b42cad324253dadf27dbff20fd8e8c8e16 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -45,6 +45,7 @@ if(WITH_GPU)
add_simple_unittest(BlockExpandOpTest)
add_simple_unittest(CropOpTest)
add_simple_unittest(SwitchOpTest)
+ add_simple_unittest(ScaleSubRegionOpTest)
endif()
add_simple_unittest(Im2ColTest)
diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h
index ba446bf92da264fafa1fb47a2c30da9cb13176ce..370940532ef40335be54a3e6467de0409e923ec4 100644
--- a/paddle/function/FunctionTest.h
+++ b/paddle/function/FunctionTest.h
@@ -110,6 +110,7 @@ public:
function2_(FunctionBase::funcRegistrar_.createByType(name2)) {
function1_->init(config);
function2_->init(config);
+ initArgsCallback_ = nullptr;
}
~Compare2Function() {}
@@ -170,6 +171,10 @@ public:
*seq2_));
}
+ void registerInitCallback(std::function callback) {  // store a hook that is later called as callback(input arg, input index)
+ initArgsCallback_ = callback;  // replaces any previously registered callback
+ }
+
// output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size =
@@ -340,6 +345,10 @@ protected:
initArg(*func1Inputs_[i]);
}
+ if (initArgsCallback_ != nullptr) {
+ initArgsCallback_(*func1Inputs_[i], i);
+ }
+
copyArg_(*func1Inputs_[i], *func2Inputs_[i]);
}
}
@@ -386,6 +395,7 @@ protected:
std::shared_ptr seq1_;
std::shared_ptr seq2_;
test::CopyArgument copyArg_;
+ std::function initArgsCallback_;
};
class CpuGpuFuncCompare
diff --git a/paddle/function/ScaleSubRegionOp.cpp b/paddle/function/ScaleSubRegionOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a080505d7df83a6c0a9d88fbcb7863fc0e1f7b21
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOp.cpp
@@ -0,0 +1,155 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionOp.h"
+#include "paddle/function/TensorShape.h"
+
+namespace paddle {
+
+template <>  // host-side specialization: plain loops over a memcpy'd buffer
+void ScaleSubRegion(real* outputs,
+ const real* inputs,
+ const real* indices,
+ const TensorShape shape,
+ const FuncConfig& conf) {
+ real value = conf.get("value");  // multiplier applied to every element inside the sub region
+
+ int number = shape[0];  // shape is read as [N, C, H, W]
+ int channel = shape[1];
+ int height = shape[2];
+ int width = shape[3];
+
+ memcpy(outputs, inputs, number * channel * height * width * sizeof(real));  // start from a full copy; only the sub region is modified below
+
+ for (int n = 0; n < number; ++n) {
+ // indices are 1-based with an inclusive end, hence "- 1" on the start and "<" on the end
+ int offset = n * 6;  // six values per sample: C start/end, H start/end, W start/end
+ for (int c = indices[offset] - 1; c < indices[offset + 1]; ++c) {
+ for (int h = indices[offset + 2] - 1; h < indices[offset + 3]; ++h) {
+ for (int w = indices[offset + 4] - 1; w < indices[offset + 5]; ++w) {
+ int idx = ((n * channel + c) * height + h) * width + w;  // flat offset of element (n, c, h, w)
+ outputs[idx] *= value;
+ }
+ }
+ }
+ }
+}
+
+template <>  // host-side specialization of the backward pass
+void ScaleSubRegionGrad(const real* inGrad,
+ real* outGrad,
+ const real* indices,
+ const TensorShape shape,
+ const FuncConfig& conf) {
+ real value = conf.get("value");  // same multiplier that the forward pass applies inside the sub region
+
+ int number = shape[0];  // shape is read as [N, C, H, W]
+ int channel = shape[1];
+ int height = shape[2];
+ int width = shape[3];
+
+ for (int n = 0; n < number; ++n) {
+ for (int c = 0; c < channel; ++c) {
+ for (int h = 0; h < height; ++h) {
+ for (int w = 0; w < width; ++w) {
+ int idx = ((n * channel + c) * height + h) * width + w;  // flat offset of element (n, c, h, w)
+ int offset = n * 6;  // six 1-based, inclusive bounds per sample: C, H, W start/end
+ if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+ h >= (indices[offset + 2] - 1) &&
+ h <= (indices[offset + 3] - 1) &&
+ w >= (indices[offset + 4] - 1) &&
+ w <= (indices[offset + 5] - 1)) {
+ outGrad[idx] += inGrad[idx] * value;  // inside the region: gradient is scaled and accumulated
+ } else {
+ outGrad[idx] += inGrad[idx];  // outside the region: gradient is accumulated unchanged
+ }
+ }
+ }
+ }
+ }
+}
+
+/**
+ * \brief For each instance, ScaleSubRegion can be used to multiply a value to
+ * a specified sub continuous region. By providing start index and end
+ * index for C/H/W, you can specify the location and shape of the region.
+ *
+ * Argument in this Function:
+ * \param inputs A 4-D tensor with shape [N, C, H, W], only one input.
+ * \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
+ * \param outputs A 4-D tensor with same shape as inputs, output value.
+ */
+template
+class ScaleSubRegionFunc : public FunctionBase {  // FunctionBase wrapper that forwards to ScaleSubRegion
+public:
+ void init(const FuncConfig& config) override { conf_ = config; }  // keep the config; "value" is read from it by ScaleSubRegion
+
+ void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+ CHECK_EQ(2UL, inputs.size());  // inputs[0] is the data tensor, inputs[1] the indices
+ CHECK_EQ(1UL, outputs.size());
+ CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);  // the output must be an ASSIGN_TO argument
+
+ TensorShape shape = inputs[0].shape();  // shape of the data tensor, not of the indices
+
+ ScaleSubRegion(outputs[0].data(),
+ inputs[0].data(),
+ inputs[1].data(),
+ shape,
+ conf_);
+ }
+
+private:
+ FuncConfig conf_;  // configuration captured in init()
+};
+
+/**
+ * \brief The backward propagation of ScaleSubRegion Function.
+ *
+ * Argument in this Function:
+ * \param inputs A 4-D tensor with shape [N, C, H, W], output gradient.
+ * \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
+ * \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
+ */
+
+template
+class ScaleSubRegionGradFunc : public FunctionBase {  // FunctionBase wrapper that forwards to ScaleSubRegionGrad
+public:
+ void init(const FuncConfig& config) override { conf_ = config; }  // keep the config; "value" is read from it by ScaleSubRegionGrad
+
+ void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+ CHECK_EQ(2UL, inputs.size());  // inputs[0] is the incoming gradient, inputs[1] the indices
+ CHECK_EQ(1UL, outputs.size());
+ CHECK_EQ(outputs[0].getArgType(), ADD_TO);  // the output must be an ADD_TO argument; the gradient code uses "+="
+
+ TensorShape shape = inputs[0].shape();  // shape of the gradient tensor, not of the indices
+
+ ScaleSubRegionGrad(inputs[0].data(),
+ outputs[0].data(),
+ inputs[1].data(),
+ shape,
+ conf_);
+ }
+
+private:
+ FuncConfig conf_;  // configuration captured in init()
+};
+
+REGISTER_TYPED_FUNC(ScaleSubRegion, CPU, ScaleSubRegionFunc);  // CPU registration of the forward function
+REGISTER_TYPED_FUNC(ScaleSubRegionGrad, CPU, ScaleSubRegionGradFunc);  // CPU registration of the backward function
+#ifdef PADDLE_WITH_CUDA
+REGISTER_TYPED_FUNC(ScaleSubRegion, GPU, ScaleSubRegionFunc);  // GPU registrations exist only when PADDLE_WITH_CUDA is defined
+REGISTER_TYPED_FUNC(ScaleSubRegionGrad, GPU, ScaleSubRegionGradFunc);
+#endif
+
+} // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOp.h b/paddle/function/ScaleSubRegionOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..0480c8577f3fbf3bc9e94b635df96a31b103e9e3
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOp.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Function to multiply a value to values in specified sub continuous
+ * region. Indices must be provided to indicate the location and shape of
+ * the region and the multiplied value is passed by configure variable.
+ *
+ *
+ * \param[out] outputs Output value.
+ * \param[in] inputs Input data which contains NCHW information.
+ * \param[in] indices Indices data to indicate the sub region.
+ * \param[in] shape Tensor shape of input value.
+ * \param[in] conf Configure variable which contains the multiplied value.
+ */
+template
+void ScaleSubRegion(real* outputs,
+ const real* inputs,
+ const real* indices,
+ const TensorShape shape,
+ const FuncConfig& conf);
+
+/**
+ * \brief Backward propagation function of ScaleSubRegion.
+ *
+ * \param[in] inGrad Gradient with respect to the forward output (read only).
+ * \param[out] outGrad Gradient with respect to the forward input.
+ * \param[in] indices Indices data.
+ * \param[in] shape The Shape of input tensor.
+ * \param[in] conf Configure variable.
+ */
+template
+void ScaleSubRegionGrad(const real* inGrad,
+ real* outGrad,
+ const real* indices,
+ const TensorShape shape,
+ const FuncConfig& conf);
+} // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOpGpu.cu b/paddle/function/ScaleSubRegionOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8aae2e44c3fdc8b516e66ecfd2e04f466a17dde9
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOpGpu.cu
@@ -0,0 +1,116 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionOp.h"
+#include "hl_base.h"
+
+namespace paddle {
+
+__global__ void KeScaleSubRegion(real* outputs,  // forward kernel: each thread writes one output element
+ const real* inputs,
+ const real* indices,
+ real value,
+ int channel,
+ int height,
+ int width,
+ int nthreads) {
+ const int idx = threadIdx.x + blockIdx.x * blockDim.x;  // global thread index, also used as the flat element offset
+ if (idx < nthreads) {  // threads past the element count do nothing
+ const int w = idx % width;  // decompose the flat offset into (n, c, h, w)
+ const int h = (idx / width) % height;
+ const int c = (idx / width / height) % channel;
+ const int n = idx / width / height / channel;
+
+ const int offset = n * 6;  // six 1-based, inclusive bounds per sample: C, H, W start/end
+ if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+ h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
+ w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
+ outputs[idx] = inputs[idx] * value;  // inside the region: scaled copy
+ } else {
+ outputs[idx] = inputs[idx];  // outside the region: plain copy
+ }
+ }
+}
+
+template <>  // specialization that launches KeScaleSubRegion
+void ScaleSubRegion(real* outputs,
+ const real* inputs,
+ const real* indices,
+ const TensorShape shape,
+ const FuncConfig& conf) {
+ real value = conf.get("value");  // multiplier applied inside the sub region
+
+ int number = shape[0];  // shape is read as [N, C, H, W]
+ int channel = shape[1];
+ int height = shape[2];
+ int width = shape[3];
+
+ size_t nth = number * channel * height * width;  // total element count; NOTE(review): the product is computed in int before widening
+ int blockSize = 1024;
+ int gridSize = (nth + blockSize - 1) / blockSize;  // ceil(nth / blockSize) so every element gets a thread
+
+ KeScaleSubRegion<<>>(
+ outputs, inputs, indices, value, channel, height, width, nth);  // nth is passed to the kernel's int nthreads parameter
+ CHECK_SYNC("ScaleSubRegion");
+}
+
+__global__ void KeScaleSubRegionDiff(const real* inGrad,  // backward kernel: each thread accumulates into one outGrad element
+ real* outGrad,
+ const real* indices,
+ real value,
+ int channel,
+ int height,
+ int width,
+ int nthreads) {
+ const int idx = threadIdx.x + blockIdx.x * blockDim.x;  // global thread index, also used as the flat element offset
+ if (idx < nthreads) {  // threads past the element count do nothing
+ const int w = idx % width;  // decompose the flat offset into (n, c, h, w)
+ const int h = (idx / width) % height;
+ const int c = (idx / width / height) % channel;
+ const int n = idx / width / height / channel;
+
+ const int offset = n * 6;  // six 1-based, inclusive bounds per sample: C, H, W start/end
+ if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+ h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
+ w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
+ outGrad[idx] += inGrad[idx] * value;  // inside the region: scaled gradient is accumulated
+ } else {
+ outGrad[idx] += inGrad[idx];  // outside the region: gradient is accumulated unchanged
+ }
+ }
+}
+
+template <>  // specialization that launches KeScaleSubRegionDiff
+void ScaleSubRegionGrad(const real* inGrad,
+ real* outGrad,
+ const real* indices,
+ const TensorShape shape,
+ const FuncConfig& conf) {
+ real value = conf.get("value");  // multiplier applied inside the sub region
+
+ int number = shape[0];  // shape is read as [N, C, H, W]
+ int channel = shape[1];
+ int height = shape[2];
+ int width = shape[3];
+
+ size_t nth = number * channel * height * width;  // total element count; NOTE(review): the product is computed in int before widening
+ int blockSize = 1024;
+ int gridSize = (nth + blockSize - 1) / blockSize;  // ceil(nth / blockSize) so every element gets a thread
+
+ KeScaleSubRegionDiff<<>>(
+ inGrad, outGrad, indices, value, channel, height, width, nth);  // nth is passed to the kernel's int nthreads parameter
+ CHECK_SYNC("ScaleSubRegionGrad");
+}
+
+} // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOpTest.cpp b/paddle/function/ScaleSubRegionOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43331f258dddaa43cbc8cc77519e299de7e98290
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOpTest.cpp
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include "FunctionTest.h"
+
+namespace paddle {
+
+TEST(ScaleSubRegion, real) {  // sweeps shapes, scale values and region placement
+ for (size_t numSamples : {5, 32}) {
+ for (size_t channels : {5, 32}) {
+ for (size_t imgSizeH : {5, 33}) {
+ for (size_t imgSizeW : {5, 32}) {
+ for (real value : {-0.5, 0.0, 0.5}) {
+ for (bool firstHalf : {false, true}) {
+ VLOG(3) << " numSamples=" << numSamples
+ << " channels=" << channels << " imgSizeH=" << imgSizeH
+ << " imgSizeW=" << imgSizeW;
+
+ for (bool testGrad : {false, true}) {  // false: forward, true: gradient
+ CpuGpuFuncCompare compare(
+ testGrad ? "ScaleSubRegionGrad" : "ScaleSubRegion",
+ FuncConfig().set("value", value));
+
+ TensorShape shape{numSamples, channels, imgSizeH, imgSizeW};
+ TensorShape indicesShape{numSamples, 6};  // six bounds per sample
+
+ compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape));
+ compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, indicesShape));
+
+ compare.registerInitCallback([=](BufferArg& arg, size_t index) {
+ if (index == 1) {  // only the indices input is filled here
+ real* data = (real*)arg.data();
+
+ for (size_t i = 0; i < numSamples; ++i) {
+ size_t offset = i * 6;
+ data[offset] = firstHalf ? 1 : channels / 2;  // channel start
+ data[offset + 1] = firstHalf ? channels / 2 : channels;  // channel end
+ data[offset + 2] = firstHalf ? 1 : imgSizeH / 2;  // height start
+ data[offset + 3] = firstHalf ? imgSizeH / 2 : imgSizeH;  // height end
+ data[offset + 4] = firstHalf ? 1 : imgSizeW / 2;  // width start
+ data[offset + 5] = firstHalf ? imgSizeW / 2 : imgSizeW;  // width end
+ }
+ }
+ });
+
+ compare.addOutputs(  // gradient run uses ADD_TO, forward run uses ASSIGN_TO
+ BufferArg(
+ VALUE_TYPE_FLOAT, shape, testGrad ? ADD_TO : ASSIGN_TO),
+ testGrad ? ADD_TO : ASSIGN_TO);
+ compare.run();
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp
index 8eb700723f2cf7dda969739bb5e3d48358d278a0..6ffe4fbec643e50d27924a989875454d307f5b9b 100644
--- a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp
@@ -62,16 +62,14 @@ void MKLDNNAddtoLayer::resetFwd(std::vector& pipeline,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
- if (biases_) {
- LOG(FATAL) << "not implemented yet";
- }
- resetFwdBuffers(inVals_, out);
+ resetFwdBuffers(inVals_, bias, out);
in = inVals_[0];
std::shared_ptr fwdPD;
- resetFwdPD(fwdPD, inVals_, out);
+ std::shared_ptr biasPD;
+ resetFwdPD(fwdPD, biasPD, inVals_, bias, out);
- resetFwdPipeline(pipeline, fwdPD, inVals_, out);
+ resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, bias, out);
}
void MKLDNNAddtoLayer::resetBwd(std::vector& pipeline,
@@ -79,7 +77,7 @@ void MKLDNNAddtoLayer::resetBwd(std::vector& pipeline,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
- resetBwdBuffers(inGrads_, out);
+ resetBwdBuffers(inGrads_, bias, out);
in = inGrads_[0];
// backward only need share output grad to input grad
@@ -89,6 +87,20 @@ void MKLDNNAddtoLayer::resetBwd(std::vector& pipeline,
inputLayers_[i]->getOutputGrad()->setData(inGrads_[i]->getData());
}
}
+
+ // backward bias
+ bwdBias_ = nullptr;
+ if (bias) {
+ std::vector scales(bs_, 1.0);
+ std::vector srcPDs(bs_, bias->getPrimitiveDesc());
+ auto biasPD = sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs);
+ std::vector srcs;
+ for (size_t i = 0; i < grads_.size(); ++i) {
+ srcs.push_back(*(grads_[i]));
+ }
+ bwdBias_.reset(new sum(biasPD, srcs, *bias));
+ pipeline.push_back(*bwdBias_);
+ }
}
void MKLDNNAddtoLayer::updateWeights(const UpdateCallback& callback) {
@@ -97,7 +109,25 @@ void MKLDNNAddtoLayer::updateWeights(const UpdateCallback& callback) {
}
}
+void MKLDNNAddtoLayer::prepareBias(MKLDNNMatrixPtr& bias,  // out: wraps biasMat
+ const MatrixPtr& biasMat,
+ const MKLDNNMatrixPtr& out,
+ std::vector& outs) {  // out: one entry per sample of `out`
+ auto pd = MKLDNNMatrix::createPrimitiveDesc(
+ {(int)layerSize_}, memory::format::x, engine_);  // 1-D, layerSize_ elements
+ bias = MKLDNNMatrix::create(pd, biasMat);
+ outs.clear();
+ real* data = out->getData();
+ CHECK_EQ(bs_ * layerSize_, out->getElementCnt());  // out must be bs_ rows of layerSize_
+ for (int i = 0; i < bs_; ++i) {
+ MatrixPtr tmp =  // row i of `out`, built on the existing data pointer
+ Matrix::create(data + i * layerSize_, 1, layerSize_, false, false);
+ outs.push_back(MKLDNNMatrix::create(bias->getPrimitiveDesc(), tmp));  // same desc as bias
+ }
+}
+
void MKLDNNAddtoLayer::resetFwdBuffers(std::vector& inputs,
+ MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
inputs.resize(inputLayers_.size());
for (size_t i = 0; i < inputs.size(); i++) {
@@ -110,12 +140,20 @@ void MKLDNNAddtoLayer::resetFwdBuffers(std::vector& inputs,
}
resetOutValue(out, inputs[0]->getPrimitiveDesc());
+
+ if (biases_ && biases_->getW()) {
+ prepareBias(bias, biases_->getW(), out, vals_);
+ } else {
+ bias = nullptr;
+ }
}
void MKLDNNAddtoLayer::resetFwdPD(std::shared_ptr& pd,
+ std::shared_ptr& biasPD,
std::vector& inputs,
+ MKLDNNMatrixPtr bias,
MKLDNNMatrixPtr out) {
- std::vector scales(inputs.size(), 1.0);
+ std::vector scales(inputs.size(), 1.0);
std::vector srcPDs;
for (size_t i = 0; i < inputs.size(); i++) {
srcPDs.push_back(inputs[i]->getPrimitiveDesc());
@@ -123,12 +161,23 @@ void MKLDNNAddtoLayer::resetFwdPD(std::shared_ptr& pd,
CHECK(out);
pd.reset(new sum::primitive_desc(out->getMemoryDesc(), scales, srcPDs));
CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
+
+ biasPD = nullptr;
+ if (bias) {
+ std::vector scales(2, 1.0);
+ std::vector srcPDs(2, bias->getPrimitiveDesc());
+ biasPD.reset(
+ new sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs));
+ CHECK_PRIMITIVE_DESC_EQ(bias, biasPD->dst_primitive_desc());
+ }
}
void MKLDNNAddtoLayer::resetFwdPipeline(
std::vector& pipeline,
std::shared_ptr& pd,
+ std::shared_ptr& biasPD,
std::vector& inputs,
+ MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
std::vector srcs;
for (size_t i = 0; i < inputs.size(); i++) {
@@ -136,9 +185,23 @@ void MKLDNNAddtoLayer::resetFwdPipeline(
}
fwd_.reset(new sum(*pd, srcs, *out));
pipeline.push_back(*fwd_);
+
+ fwdBias_.clear();
+ if (biasPD == nullptr || bias == nullptr) {
+ return;
+ }
+ fwdBias_.resize(vals_.size());
+ for (size_t i = 0; i < vals_.size(); ++i) {
+ std::vector srcs;
+ srcs.push_back(*(vals_[i]));
+ srcs.push_back(*bias);
+ fwdBias_[i].reset(new sum(*biasPD, srcs, *vals_[i]));
+ pipeline.push_back(*fwdBias_[i]);
+ }
}
void MKLDNNAddtoLayer::resetBwdBuffers(std::vector& inputs,
+ MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
CHECK(outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc());
@@ -149,6 +212,12 @@ void MKLDNNAddtoLayer::resetBwdBuffers(std::vector& inputs,
resetInGrad(inputs[i], inVal_->getPrimitiveDesc(), i);
CHECK_PRIMITIVE_DESC_EQ(inputs[i], out->getPrimitiveDesc());
}
+
+ if (biases_ && biases_->getWGrad()) {
+ prepareBias(bias, biases_->getWGrad(), out, grads_);
+ } else {
+ bias = nullptr;
+ }
}
} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNAddtoLayer.h b/paddle/gserver/layers/MKLDNNAddtoLayer.h
index 15f74ec5bdf3d1e4ae5e09051be6be418590a67a..24504b7b4f50726e2b2757ca3029461cdc27b411 100644
--- a/paddle/gserver/layers/MKLDNNAddtoLayer.h
+++ b/paddle/gserver/layers/MKLDNNAddtoLayer.h
@@ -32,9 +32,15 @@ protected:
// layer size == ic * ih * iw == oc * oh *ow, and can not be changed
size_t layerSize_;
- // TODO(TJ): this part has not been optimized by MKL-DNN
std::unique_ptr biases_;
+ // buffers for adding bias
+ std::vector vals_;
+ std::vector grads_;
+ // primitives for adding bias
+ std::vector> fwdBias_;
+ std::shared_ptr bwdBias_;
+
public:
explicit MKLDNNAddtoLayer(const LayerConfig& config) : MKLDNNLayer(config) {}
@@ -91,20 +97,34 @@ protected:
* reset pipeline.
*/
void resetFwdBuffers(std::vector& inputs,
+ MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr& pd,
+ std::shared_ptr& biasPD,
std::vector& inputs,
+ MKLDNNMatrixPtr bias,
MKLDNNMatrixPtr out);
void resetFwdPipeline(std::vector& pipeline,
std::shared_ptr& pd,
+ std::shared_ptr& biasPD,
std::vector& inputs,
+ MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(inputs, output, bias)
*/
void resetBwdBuffers(std::vector& inputs,
+ MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
+
+ /**
+ * prepare for bias
+ */
+ void prepareBias(MKLDNNMatrixPtr& bias,
+ const MatrixPtr& biasMat,
+ const MKLDNNMatrixPtr& out,
+ std::vector& outs);
};
} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp
index 82ef344c7b2aa0093a5f0a28780592dea5d51efe..e75ac5ba4647a8267b7bc189893bd7adb5c3053f 100644
--- a/paddle/gserver/layers/MKLDNNLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNLayer.cpp
@@ -287,7 +287,7 @@ void MKLDNNLayer::resetMergeGrad(MKLDNNMatrixPtr& out) {
return;
}
CHECK(out) << "should have reset internal ouput grad";
- std::vector scales(outputMap_.size(), 1.0);
+ std::vector scales(outputMap_.size(), 1.0);
std::vector srcPDs;
std::vector srcs;
for (auto it = outputMap_.begin(); it != outputMap_.end(); ++it) {
diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.cpp b/paddle/gserver/layers/ScaleSubRegionLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aa6778aef4e893208fd064ca22e217c6c4d960f9
--- /dev/null
+++ b/paddle/gserver/layers/ScaleSubRegionLayer.cpp
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionLayer.h"
+#include "paddle/utils/Stat.h"
+namespace paddle {
+
+REGISTER_LAYER(scale_sub_region, ScaleSubRegionLayer);
+
+bool ScaleSubRegionLayer::init(const LayerMap& layerMap,  // reads config, creates functions
+ const ParameterMap& parameterMap) {
+ Layer::init(layerMap, parameterMap);
+ CHECK_EQ(static_cast(inputLayers_.size()), 2);  // expects exactly two inputs
+ auto& conf = config_.inputs(0).scale_sub_region_conf();
+ value_ = conf.value();  // scale factor taken from input 0's config
+
+ createFunction(forward_, "ScaleSubRegion", FuncConfig().set("value", value_));
+ createFunction(
+ backward_, "ScaleSubRegionGrad", FuncConfig().set("value", value_));
+
+ return true;
+}
+
+void ScaleSubRegionLayer::forward(PassType passType) {  // runs the "ScaleSubRegion" function
+ Layer::forward(passType);
+ auto in0 = getInput(0);
+ imgH_ = in0.getFrameHeight();
+ imgW_ = in0.getFrameWidth();
+ if (imgH_ == 0 || imgW_ == 0) {  // input has no frame size: use the configured one
+ auto& conf = config_.inputs(0).scale_sub_region_conf();
+ imgH_ = conf.image_conf().img_size_y();
+ imgW_ = conf.image_conf().img_size();
+ }
+ MatrixPtr imgV = in0.value;
+ size_t batchSize = imgV->getHeight();
+ size_t spatialSize = imgH_ * imgW_;
+ channelsNum_ = imgV->getWidth() / spatialSize;  // channels = row width / (H * W)
+ shape_ = TensorShape({batchSize, channelsNum_, imgH_, imgW_});
+
+ resetOutput(batchSize, imgV->getWidth());  // output has the same size as input 0
+ auto& out = getOutput();
+ out.setFrameHeight(imgH_);
+ out.setFrameWidth(imgW_);
+
+ MatrixPtr indicesV = getInputValue(1);
+ indicesShape_ = TensorShape({batchSize, 6});  // six region bounds per sample
+
+ REGISTER_TIMER_INFO("ScaleSubRegionForward", getName().c_str());
+ BufferArgs inArgs;
+ BufferArgs outArgs;
+ inArgs.addArg(*imgV, shape_);
+ inArgs.addArg(*indicesV, indicesShape_);
+ outArgs.addArg(*out.value, shape_, ASSIGN_TO);  // output is overwritten
+ forward_[0]->calc(inArgs, outArgs);
+}
+
+void ScaleSubRegionLayer::backward(const UpdateCallback& callback) {  // callback is unused
+ REGISTER_TIMER_INFO("ScaleSubRegionBackward", getName().c_str());
+ BufferArgs inArgs;
+ BufferArgs outArgs;
+ inArgs.addArg(*getOutputGrad(), shape_);  // shape_ was set by the last forward()
+ inArgs.addArg(*getInputValue(1), indicesShape_);
+ outArgs.addArg(*getInputGrad(0), shape_, ADD_TO);  // accumulates into input 0's gradient
+ backward_[0]->calc(inArgs, outArgs);
+}
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.h b/paddle/gserver/layers/ScaleSubRegionLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..a27c56de93bb6fdde0f95cd4c5abe5dfabe4e858
--- /dev/null
+++ b/paddle/gserver/layers/ScaleSubRegionLayer.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * \brief For each instance, this layer can be used to multiply a value to a
+ * specified sub continuous region. By providing start index and end
+ * index for C/H/W, you can specify the location and shape of the
+ * region.
+ *
+ * input_0: Input value.
+ * input_1: Indices value to specify the location and shape of the
+ * region.
+ */
+class ScaleSubRegionLayer : public Layer {
+public:
+ explicit ScaleSubRegionLayer(const LayerConfig& config) : Layer(config) {}
+
+ ~ScaleSubRegionLayer() {}
+
+ bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
+
+ void forward(PassType passType);
+
+ void backward(const UpdateCallback& callback = nullptr);
+
+protected:
+ TensorShape shape_;  // {batch, channels, height, width} of input 0
+ TensorShape indicesShape_;  // {batch, 6}
+ size_t imgH_;  // image height
+ size_t imgW_;  // image width
+ size_t channelsNum_;  // input row width divided by imgH_ * imgW_
+ real value_;  // factor passed to the scale functions as "value"
+};
+
+} // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 1a46fb49153a0aa4228f58db481b950bc2d6de83..df73e6781533def5641635e9dfa9c9e4e8a0b57f 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -53,7 +53,7 @@ TEST(Operator, dot_mul) {
TEST(Projection, context) {
for (auto contextStart : {-5, -3, -1, 0, 3}) {
for (auto contextLength : {1, 2, 5, 7}) {
- for (auto batchSize : {1, 2, 5, 20, 50}) {
+ for (auto batchSize : {1, 2, 5, 20}) {
for (auto trainablePadding : {false, true}) {
LOG(INFO) << " contextStart=" << contextStart
<< " contextLength=" << contextLength
@@ -585,14 +585,14 @@ TEST(Layer, maxoutLayer) {
}
void testFcLayer(string format, size_t nnz) {
TestConfig config;
- config.biasSize = 4096;
+ config.biasSize = 1024;
config.layerConfig.set_type("fc");
- config.layerConfig.set_size(4096);
+ config.layerConfig.set_size(1024);
config.layerConfig.set_active_type("sigmoid");
config.layerConfig.set_drop_rate(0.1);
config.inputDefs.push_back(
- {INPUT_DATA, "layer_0", 8192, nnz, ParaSparse(format)});
+ {INPUT_DATA, "layer_0", 2048, nnz, ParaSparse(format)});
config.layerConfig.add_inputs();
LOG(INFO) << config.inputDefs[0].sparse.sparse << " "
@@ -609,9 +609,9 @@ void testFcLayer(string format, size_t nnz) {
}
TEST(Layer, fcLayer) {
- testFcLayer("", 4096 * 4096 * 2);
- testFcLayer("csc", 4096 * 40);
- testFcLayer("csr", 4096 * 40);
+ testFcLayer("", 1024 * 1024 * 2);
+ testFcLayer("csc", 1024 * 10);
+ testFcLayer("csr", 1024 * 10);
}
TEST(Layer, SelectiveFullyConnectedLayer) {
@@ -1995,7 +1995,7 @@ TEST(Layer, multibox_loss) {
TEST(Layer, TransLayer) {
TestConfig config;
const int height = 128;
- const int width = 1028;
+ const int width = 256;
config.layerConfig.set_type("trans");
config.layerConfig.set_size(width);
@@ -2358,6 +2358,38 @@ TEST(Layer, ScaleShiftLayer) {
}
}
+TEST(Layer, ScaleSubRegionLayer) {  // gradient check for scale_sub_region
+  const size_t batchSize = 64;
+  const size_t size = 4096;  // 4 channels * 32 * 32, matching imgConf below
+  TestConfig config;
+  config.layerConfig.set_type("scale_sub_region");
+  config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
+  MatrixPtr indicesV = Matrix::create(batchSize, 6, false, false);
+  auto* data = indicesV->getData();
+  for (size_t i = 0; i < batchSize; ++i) {  // each row holds six bounds
+    data[i * 6] = 2;       // channel start
+    data[i * 6 + 1] = 4;   // channel end
+    data[i * 6 + 2] = 16;  // height start
+    data[i * 6 + 3] = 32;  // height end
+    data[i * 6 + 4] = 16;  // width start
+    data[i * 6 + 5] = 32;  // width end
+  }
+  config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "indices", indicesV, {}});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  ScaleSubRegionConfig* scaleSubRegionConf =
+      input->mutable_scale_sub_region_conf();
+  ImageConfig* imgConf = scaleSubRegionConf->mutable_image_conf();
+  imgConf->set_img_size(32);
+  imgConf->set_img_size_y(32);
+  imgConf->set_channels(4);
+  scaleSubRegionConf->set_value(2.0);
+  config.layerConfig.add_inputs();  // second input: the indices
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "scale_sub_region", batchSize, false, useGpu, false);
+  }
+}
+
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp
index 2e8d9f3333b36005c9b3b28449c76a4a44c74cc6..3960d699ac8dc08316ee413116878ee3eda65793 100644
--- a/paddle/gserver/tests/test_MKLDNN.cpp
+++ b/paddle/gserver/tests/test_MKLDNN.cpp
@@ -300,13 +300,8 @@ void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) {
TestConfig dnnConfig;
getAddtoConfig(dnnConfig, pm, nInputs);
dnnConfig.layerConfig.set_type("mkldnn_addto");
- // TODO(TJ): test with bias
- for (auto withBias : {false}) {
- if (withBias) {
- dnnConfig.biasSize = pm.ic * pm.ih * pm.iw;
- } else {
- dnnConfig.biasSize = 0;
- }
+ for (auto withBias : {false, true}) {
+ dnnConfig.biasSize = withBias ? pm.ic * pm.ih * pm.iw : 0;
RUN_MKLDNN_TEST_LAYER(dnnConfig, "addto", pm)
}
}
diff --git a/paddle/math/tests/TensorCheck.h b/paddle/math/tests/TensorCheck.h
index 5bc4a03067a75527fa30e5bb5526f93dc7b9fdcc..b998e5772e70d0a0ec79dc4064dcbaa2c302efd2 100644
--- a/paddle/math/tests/TensorCheck.h
+++ b/paddle/math/tests/TensorCheck.h
@@ -169,7 +169,7 @@ void TensorCheck(AssertEq compare,
count++;
}
}
- EXPECT_EQ(count, 0) << "There are " << count << " different element.";
+ EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
template
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index eae87a5141ef1284630170b07d22a0cf9cd977b0..29ce44c23308cb5ae1c1df5c9be1412c28abe96f 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -195,8 +195,13 @@ op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
-op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
- DEPS net_op tensor_array)
+if(WITH_TESTING)
+ op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
+ DEPS net_op tensor_array gtest)
+else()
+ op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
+ DEPS net_op tensor_array)
+endif()
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index d0c4c0d25d6f4e3ab7acd72d62a8a17fa102637b..1776f33105367447759aa91c25263dfc53bd2f99 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -65,7 +65,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
size_t num_samples = inference->dims()[0];
size_t infer_width = inference->dims()[1];
- cudaMemset((void**)&accuracy_data, 0, sizeof(float));
+ PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
if (num_samples == 0) {
return;
diff --git a/paddle/operators/array_operator.h b/paddle/operators/array_operator.h
new file mode 100644
index 0000000000000000000000000000000000000000..666043e824f885e9c0e79e319d0a38ba108c209a
--- /dev/null
+++ b/paddle/operators/array_operator.h
@@ -0,0 +1,50 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+#include "paddle/framework/lod_tensor_array.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+class ArrayOp : public framework::OperatorBase {  // adds GetOffset() for subclasses
+ public:
+ ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
+ const framework::VariableNameMap &outputs,
+ const framework::AttributeMap &attrs)
+ : OperatorBase(type, inputs, outputs, attrs) {}
+
+ protected:
+ size_t GetOffset(const framework::Scope &scope,  // returns the value of input "I"
+ const platform::DeviceContext &dev_ctx) const {
+ auto *i = scope.FindVar(Input("I"));
+ PADDLE_ENFORCE(i != nullptr, "I must be set");
+ auto &i_tensor = i->Get();
+ PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);  // "I" must hold a single element
+ size_t offset;
+ if (platform::is_gpu_place(i_tensor.place())) {
+ // FIXME: Avoid copy from GPU to CPU
+ framework::Tensor t;
+ t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
+ dev_ctx.Wait();  // wait on the device context before reading t
+ offset = static_cast(*t.data());
+ } else {
+ offset = static_cast(*i_tensor.data());  // not on GPU: read in place
+ }
+ return offset;
+ }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc
index 6cd9c06b8ae3d3b17be83268c2f5d4002705b111..c0903bb4e5ca7f160e19eefab99af7e3e4a8ed76 100644
--- a/paddle/operators/array_to_lod_tensor_op.cc
+++ b/paddle/operators/array_to_lod_tensor_op.cc
@@ -140,6 +140,23 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
"ArrayToLoDTensorOp must has input X.");
PADDLE_ENFORCE(context->HasInput("RankTable"),
"ArrayToLoDTensorOp must has input RankTable.");
+ context->SetOutputDim("Out", context->GetInputDim("X"));
+ }
+};
+
+class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+ using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+ std::unique_ptr Apply() const override {  // builds the gradient op description
+ auto *grad_op = new framework::OpDescBind();
+ grad_op->SetType("lod_tensor_to_array");  // gradient is a lod_tensor_to_array op
+ grad_op->SetInput("X", OutputGrad("Out"));
+ grad_op->SetInput("RankTable", Input("RankTable"));  // same RankTable as the forward op
+ grad_op->SetOutput("Out", InputGrad("X"));
+ grad_op->SetAttrMap(Attrs());
+ return std::unique_ptr(grad_op);  // ownership passes to the returned pointer
}
};
@@ -149,4 +166,5 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(array_to_lod_tensor, ops::ArrayToLoDTensorOp,
ops::ArrayToLoDTensorOpProtoMaker,
- ops::ArrayToLoDTensorInferShape);
+ ops::ArrayToLoDTensorInferShape,
+ ops::ArrayToLoDTensorGradMaker);
diff --git a/paddle/operators/clip_by_norm_op.cc b/paddle/operators/clip_by_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d9fc532e39500fa397be80396b075e866bad9362
--- /dev/null
+++ b/paddle/operators/clip_by_norm_op.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/operators/clip_by_norm_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ClipByNormOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(framework::InferShapeContext* ctx) const override {  // Out mirrors X
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of ClipByNormOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of ClipByNormOp should not be null.");
+ auto max_norm = ctx->Attrs().Get("max_norm");
+ PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
+ auto x_dims = ctx->GetInputDim("X");
+ ctx->SetOutputDim("Out", x_dims);  // output has the same dims as the input
+ ctx->ShareLoD("X", /*->*/ "Out");
+ }
+};
+
+class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ClipByNormOpMaker(framework::OpProto* proto,  // declares X, Out and max_norm
+                    framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",  // adjacent literals are concatenated, hence the trailing space
+             "(Tensor) The input of clip_by_norm op. "
+             "The number of dimensions must be between [1, 9].");
+    AddOutput("Out",
+              "(Tensor) The output of clip_by_norm op with shape as input(X)");
+    AddAttr<float>("max_norm", "(float) The maximum norm value.");
+    AddComment(R"DOC(
+ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
+If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
+the same as 'X'. If the L2 norm of 'X' is greater than 'max_norm', 'X' will
+be linearly scaled to make the L2 norm of 'Out' equal to 'max_norm', as
+shown in the following formula:
+
+'Out' = 'max_norm' * 'X' / norm('X'),
+
+where norm('X') represents the L2 norm of 'X'.
+)DOC");
+  }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
+ ops::ClipByNormOpMaker);
+REGISTER_OP_CPU_KERNEL(
+ clip_by_norm, ops::ClipByNormKernel);
diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/clip_by_norm_op.cu
similarity index 64%
rename from paddle/operators/fill_constant_op.cu
rename to paddle/operators/clip_by_norm_op.cu
index 08c826faadf5ea9ddfe2423f0eb3933bf05f3dd8..2593a24ebbf56ecd286a726e527d2414247576e8 100644
--- a/paddle/operators/fill_constant_op.cu
+++ b/paddle/operators/clip_by_norm_op.cu
@@ -12,12 +12,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
-#include "paddle/framework/op_registry.h"
-#include "paddle/operators/fill_constant_op.h"
+#include "paddle/operators/clip_by_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
- fill_constant, ops::FillConstantOpKernel,
- ops::FillConstantOpKernel,
- ops::FillConstantOpKernel,
- ops::FillConstantOpKernel);
+ clip_by_norm, ops::ClipByNormKernel);
diff --git a/paddle/operators/clip_by_norm_op.h b/paddle/operators/clip_by_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..b26476cae9b5b2fa290bc9186b9a64c48ba703d6
--- /dev/null
+++ b/paddle/operators/clip_by_norm_op.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template
+using EigenVector = framework::EigenVector;
+
+template
+class ClipByNormKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {  // Out = X * scaling
+ auto max_norm = context.Attr("max_norm");
+ auto* input = context.Input("X");
+ auto* output = context.Output("Out");
+ output->mutable_data(context.GetPlace());
+
+ auto x = EigenVector::Flatten(*input);
+ auto out = EigenVector::Flatten(*output);
+ auto x_norm = x.square().sum().sqrt();  // L2 norm over all elements of X
+ auto place = context.GetEigenDevice();
+
+ auto temp = (x_norm <= max_norm).template cast().eval();  // result of the <= test, cast
+ auto scaling = temp + (static_cast(1) - temp) * max_norm / x_norm;  // 1 or max_norm / norm
+ Eigen::array one_dim{{1}};
+ Eigen::DSizes m_dsize(input->numel());
+ out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);  // spread over numel
+ }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc
index 8b425d14df3bc484437dc72f29abf13b887006bd..716b5ee92d0d8737d2069460f53989f691ff7c77 100644
--- a/paddle/operators/compare_op.cc
+++ b/paddle/operators/compare_op.cc
@@ -14,6 +14,7 @@
#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"
+
namespace paddle {
namespace operators {
template
@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
}
};
+class CompareOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ framework::OpKernelType GetKernelType(  // base kernel type, with place overridden
+ const framework::ExecutionContext &ctx) const override {
+ framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
+ // CompareOp kernel's device type is decided by input tensor place
+ kt.place_ = ctx.Input("X")->place();  // only input "X" is consulted
+ return kt;
+ }
+};
+
} // namespace operators
} // namespace paddle
-#define REGISTER_LOGICAL_OP(op_type, _equation) \
- struct _##op_type##Comment { \
- static char type[]; \
- static char equation[]; \
- }; \
- char _##op_type##Comment::type[]{#op_type}; \
- char _##op_type##Comment::equation[]{_equation}; \
- REGISTER_OP_WITH_KERNEL( \
- op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
- ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
+#define REGISTER_LOGICAL_OP(op_type, _equation) \
+ struct _##op_type##Comment { \
+ static char type[]; \
+ static char equation[]; \
+ }; \
+ char _##op_type##Comment::type[]{#op_type}; \
+ char _##op_type##Comment::equation[]{_equation}; \
+ REGISTER_OPERATOR( \
+ op_type, ::paddle::operators::CompareOp, \
+ ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
+ ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
diff --git a/paddle/operators/fill_constant_batch_size_like_op.cc b/paddle/operators/fill_constant_batch_size_like_op.cc
index 1019c8c60638ef84e1b09574cdd5600f035d6d83..2f25cc02dfb612e916742a05296fd4d769dc6793 100644
--- a/paddle/operators/fill_constant_batch_size_like_op.cc
+++ b/paddle/operators/fill_constant_batch_size_like_op.cc
@@ -75,10 +75,10 @@ class FillConstantBatchSizeLikeOpMaker
"with the specified value");
AddAttr>("shape", "(vector) The shape of the output");
AddAttr("input_dim_idx",
- "(int, default 0) the index of input's batch size dimension")
+ "(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr("output_dim_idx",
- "(int, default 0) the index of output's batch size dimension")
+ "(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
AddAttr("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
@@ -100,5 +100,5 @@ REGISTER_OPERATOR(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
- ops::FillConstantOpKernel,
- ops::FillConstantOpKernel);
+ ops::FillConstantBatchSizeLikeOpKernel,
+ ops::FillConstantBatchSizeLikeOpKernel);
diff --git a/paddle/operators/fill_constant_batch_size_like_op.cu b/paddle/operators/fill_constant_batch_size_like_op.cu
index 33bc3580fd7c9b1ed3620e3010c0c11b02a84f04..565c6fb5b04ab4c259ebe1fc46b3f5b512e28786 100644
--- a/paddle/operators/fill_constant_batch_size_like_op.cu
+++ b/paddle/operators/fill_constant_batch_size_like_op.cu
@@ -18,5 +18,5 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant_batch_size_like,
- ops::FillConstantOpKernel,
- ops::FillConstantOpKernel);
+ ops::FillConstantBatchSizeLikeOpKernel,
+ ops::FillConstantBatchSizeLikeOpKernel);
diff --git a/paddle/operators/fill_constant_op.h b/paddle/operators/fill_constant_batch_size_like_op.h
similarity index 94%
rename from paddle/operators/fill_constant_op.h
rename to paddle/operators/fill_constant_batch_size_like_op.h
index 48f4d9ac4c4d8fd500b62bb7cb127b1f8d6e5db6..ea184e6b979a48d39edee17c92eb7a1945fa38e9 100644
--- a/paddle/operators/fill_constant_op.h
+++ b/paddle/operators/fill_constant_batch_size_like_op.h
@@ -21,7 +21,7 @@ namespace paddle {
namespace operators {
template
-class FillConstantOpKernel : public framework::OpKernel {
+class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output("Out");
diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc
index 5a1cba51f83bb8577bc94ae23d1a44bb801ae4c7..818f113b90a4c239a857791fb9957e51d3287b97 100644
--- a/paddle/operators/fill_constant_op.cc
+++ b/paddle/operators/fill_constant_op.cc
@@ -12,33 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#include "paddle/operators/fill_constant_op.h"
+#include "paddle/framework/data_type.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
-class FillConstantOp : public framework::OperatorWithKernel {
+class FillConstantInferShape : public framework::InferShapeBase {
public:
- using framework::OperatorWithKernel::OperatorWithKernel;
-
- void InferShape(framework::InferShapeContext *ctx) const override {
+ void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null.");
auto &shape = ctx->Attrs().Get>("shape");
- std::vector shape_int64(shape.size(), 0);
- std::transform(shape.begin(), shape.end(), shape_int64.begin(),
- [](int a) { return static_cast(a); });
- auto dims = framework::make_ddim(shape_int64);
- ctx->SetOutputDim("Out", dims);
+ ctx->SetOutputDim("Out", framework::make_ddim(shape));
}
+};
- protected:
- framework::OpKernelType GetKernelType(
- const framework::ExecutionContext &ctx) const override {
- int data_type = ctx.Attr("data_type");
- VLOG(10) << " FillConstant data_type = " << data_type;
- return framework::OpKernelType(static_cast(data_type),
- ctx.device_context());
+class FillConstantOp : public framework::OperatorBase {  // fill_constant implemented directly on OperatorBase: all work happens in Run()
+ public:
+ using framework::OperatorBase::OperatorBase;
+ void Run(const framework::Scope &scope,
+ const platform::DeviceContext &dev_ctx) const override {  // resizes Out to attr "shape", allocates it, then fills it with attr "value"
+ auto data_type = static_cast(Attr("data_type"));
+ auto value = Attr("value");
+ auto force_cpu = Attr("force_cpu");
+ auto &out =
+ *scope.FindVar(Output("Out"))->GetMutable();  // NOTE(review): FindVar's result is dereferenced unchecked; presumably the variable always exists in scope -- confirm
+ out.Resize(framework::make_ddim(Attr>("shape")));
+ if (force_cpu) {
+ auto cpu = platform::CPUPlace();
+ out.mutable_data(cpu, framework::ToTypeIndex(data_type));  // force_cpu: allocate on CPUPlace regardless of dev_ctx
+ } else {
+ out.mutable_data(dev_ctx.GetPlace(), framework::ToTypeIndex(data_type));  // otherwise allocate on the place reported by dev_ctx
+ }
+ math::set_constant(dev_ctx, &out, value);  // fill happens after allocation, so the tensor already has its place
}
};
@@ -54,6 +62,11 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr>("shape", "(vector) The shape of the output");
AddAttr("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
+ AddAttr("force_cpu",
+ "(bool, default false) Force fill output variable to cpu "
+ "memory. Otherwise, fill output variable to the running "
+ "device")
+ .SetDefault(false);
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
@@ -69,10 +82,6 @@ Fill up a variable with specified constant value.
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
- ops::FillConstantOpMaker);
-REGISTER_OP_CPU_KERNEL(
- fill_constant, ops::FillConstantOpKernel,
- ops::FillConstantOpKernel,
- ops::FillConstantOpKernel,
- ops::FillConstantOpKernel);
+REGISTER_OPERATOR(fill_constant, ops::FillConstantOp,
+ ops::FillConstantInferShape, ops::FillConstantOpMaker,
+ paddle::framework::EmptyGradOpMaker);
diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc
index 5f02f5e8a12831a33683cdc53cf0feb7cb908da5..58af35564d83b9699af4f7783fb6367ff9590682 100644
--- a/paddle/operators/lod_tensor_to_array_op.cc
+++ b/paddle/operators/lod_tensor_to_array_op.cc
@@ -133,6 +133,22 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
}
};
+// Describes the backward op of lod_tensor_to_array: an array_to_lod_tensor op
+// whose input X is the gradient of the forward Out, whose output Out is the
+// gradient of the forward X, and which reuses the forward op's RankTable
+// input and attribute map.
+class LoDTensorToArrayGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr Apply() const override {
+    // Take ownership immediately so the description is released if any of
+    // the setters or accessors below throws before we return.
+    std::unique_ptr<framework::OpDescBind> grad_op(
+        new framework::OpDescBind());
+    grad_op->SetType("array_to_lod_tensor");
+    grad_op->SetInput("X", OutputGrad("Out"));
+    grad_op->SetInput("RankTable", Input("RankTable"));
+    grad_op->SetOutput("Out", InputGrad("X"));
+    grad_op->SetAttrMap(Attrs());
+    return grad_op;
+  }
+};
+
} // namespace operators
} // namespace paddle
@@ -140,4 +156,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(lod_tensor_to_array, ops::LoDTensorToArrayOp,
ops::LoDTensorToArrayOpProtoMaker,
ops::LoDTensorToArrayInferShape,
- ops::LoDTensorToArrayInferVarType);
+ ops::LoDTensorToArrayInferVarType,
+ ops::LoDTensorToArrayGradMaker);
diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc
index f4519ec16f3f694cf49941f8d23c4106f6f1ddc3..18b9cdf2a39e8226c634194ff2cc56d169979774 100644
--- a/paddle/operators/lstm_unit_op.cc
+++ b/paddle/operators/lstm_unit_op.cc
@@ -34,10 +34,10 @@ class LstmUnitOp : public framework::OperatorWithKernel {
auto c_prev_dims = ctx->GetInputDim("C_prev");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
- PADDLE_ENFORCE(x_dims[0] == c_prev_dims[0],
- "Batch size of inputs and states must be equal");
- PADDLE_ENFORCE(x_dims[1] == c_prev_dims[1] * 4,
- "Dimension of FC should equal to prev state * 4");
+ PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
+ "Batch size of inputs and states must be equal");
+ PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
+ "Dimension of FC should equal to prev state * 4");
int b_size = c_prev_dims[0]; // batch size
int s_dim = c_prev_dims[1]; // state dim
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index 2a9c09a0f16b71473e21765ab9253eb7b8bcf28c..09c3f0b1e6f787547b9253d3aeadf70674708ba0 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
+#include "paddle/framework/data_type.h"
namespace paddle {
namespace operators {
@@ -233,6 +234,52 @@ void gemv(const platform::DeviceContext& context,
template struct SetConstant;
+struct TensorSetConstant {  // functor passed to framework::VisitDataType by the set_constant_with_place specialization in this file
+ TensorSetConstant(framework::Tensor* tensor, float value)
+ : tensor_(tensor), value_(value) {}
+ template
+ void operator()() const {  // obtains the tensor's buffer on CPUPlace and std::fill()s numel() elements
+ auto cpu = platform::CPUPlace();
+ auto* begin = tensor_->mutable_data(cpu);
+ std::fill(begin, begin + tensor_->numel(), static_cast(value_));  // value_ is stored as float and cast before the fill
+ }
+ framework::Tensor* tensor_;  // tensor to fill; not owned
+ float value_;
+};
+
+template <>
+void set_constant_with_place(
+ const platform::DeviceContext& context, framework::Tensor* tensor,
+ float value) {  // explicit specialization; context is not read by this version
+ framework::VisitDataType(framework::ToDataType(tensor->type()),  // presumably invokes the functor's templated operator() for the tensor's element type -- confirm in data_type.h
+ TensorSetConstant(tensor, value));
+}
+
+struct TensorSetConstantWithPlace : public boost::static_visitor {  // visitor applied to a tensor's place by set_constant(); forwards to set_constant_with_place
+ TensorSetConstantWithPlace(const platform::DeviceContext& context,
+ framework::Tensor* tensor, float value)
+ : context_(context), tensor_(tensor), value_(value) {}
+
+ template
+ void operator()(Place place) const {  // the value of place is never read; only its type takes part in the call
+ set_constant_with_place(context_, tensor_, value_);
+ }
+
+ const platform::DeviceContext& context_;  // stored by reference, so the context must outlive this visitor
+ framework::Tensor* tensor_;  // tensor to fill; not owned
+ float value_;
+};
+
+void set_constant(const platform::DeviceContext& context,
+ framework::Tensor* tensor, float value) {  // fills *tensor with value; tensor->place() is read, so the tensor is expected to be allocated already
+ TensorSetConstantWithPlace func(context, tensor, value);
+#ifdef PADDLE_WITH_CUDA
+ tensor->place().apply_visitor(func);  // CUDA build: dispatch on the place the tensor currently reports
+#else
+ func(platform::CPUPlace());  // non-CUDA build: always take the CPUPlace path
+#endif
+}
+
} // namespace math
} // namespace operators
} // namespace paddle
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index e6fd8bf235b8539702ca2c5b39e305cb1becf5cb..255e480680499877ff599b96b8336a968cccbb34 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include "paddle/framework/data_type.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
@@ -232,6 +233,30 @@ void gemv(const platform::DeviceContext& context,
template struct SetConstant;
+struct TensorSetConstant {  // functor passed to framework::VisitDataType by the set_constant_with_place specialization in this .cu file
+ TensorSetConstant(const platform::DeviceContext& context,
+ framework::Tensor* tensor, float value)
+ : context_(context), tensor_(tensor), value_(value) {}
+
+ template
+ void operator()() const {  // delegates the fill to a SetConstant functor
+ SetConstant functor;
+ functor(context_, tensor_, static_cast(value_));  // value_ is stored as float and cast before being handed over
+ }
+
+ const platform::DeviceContext& context_;  // stored by reference, so the context must outlive this functor
+ framework::Tensor* tensor_;  // tensor to fill; not owned
+ float value_;
+};
+
+template <>
+void set_constant_with_place(
+ const platform::DeviceContext& context, framework::Tensor* tensor,
+ float value) {  // explicit specialization; unlike the std::fill version in math_function.cc, this one forwards context to the functor
+ framework::VisitDataType(framework::ToDataType(tensor->type()),  // presumably invokes the functor's templated operator() for the tensor's element type -- confirm in data_type.h
+ TensorSetConstant(context, tensor, value));
+}
+
} // namespace math
} // namespace operators
} // namespace paddle
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index 3bb5aa0332c7e2a63d20b91893c03ccd468dd863..1c9eabb2b78f1d69054b347f27854ee8ca3f3d1e 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -108,6 +108,13 @@ struct SetConstant {
}
};
+template
+void set_constant_with_place(const platform::DeviceContext& context,
+ framework::Tensor* tensor, float value);  // declaration only; explicit specializations are defined in math_function.cc and math_function.cu
+
+void set_constant(const platform::DeviceContext& context,
+ framework::Tensor* tensor, float value);  // fills *tensor with value; defined in math_function.cc, where it selects a set_constant_with_place version by place
+
} // namespace math
} // namespace operators
} // namespace paddle
diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc
index 7d84ad9aadb2892db0d0ee9cab428dc5036614e9..983c9fdcffb0a67da1bc0b5b4af9420a68bd2ac1 100644
--- a/paddle/operators/math/math_function_test.cc
+++ b/paddle/operators/math/math_function_test.cc
@@ -139,3 +139,15 @@ TEST(math_function, gemv) {
GemvTest(12, 7, true);
GemvTest(7, 9, true);
}
+
+// Checks that math::set_constant() writes the requested value into every
+// element of a tensor allocated on CPUPlace.
+TEST(math_function, set_constant) {
+  paddle::framework::Tensor t;
+  t.Resize({10, 10});
+  t.mutable_data(paddle::platform::CPUPlace());
+  // Stack-allocated so the context is released on every exit path, including
+  // when a check below ends the test early.
+  paddle::platform::CPUDeviceContext ctx;
+  paddle::operators::math::set_constant(ctx, &t, 10);
+  for (int64_t i = 0; i < t.numel(); ++i) {
+    PADDLE_ENFORCE_EQ(10, t.data()[i]);
+  }
+}
diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index 78b4bbca84d4670aba73222f1d679604d7516b02..dcc5b4286f4ac833268a779a9a7edd2ed119ffff 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -51,6 +51,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+ ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu
index 8d0741dccc1fdae069af55da49f44378e2c4ddf8..8711567b95fea355396173b5312d26d31f9ffb12 100644
--- a/paddle/operators/pool_cudnn_op.cu
+++ b/paddle/operators/pool_cudnn_op.cu
@@ -37,11 +37,11 @@ class PoolCudnnOpKernel : public framework::OpKernel {
const T *input_data = input->data();
T *output_data = output->mutable_data(ctx.GetPlace());
- std::string pooling_type = ctx.Attr("poolingType");
+ std::string pooling_type = ctx.Attr("pooling_type");
std::vector ksize = ctx.Attr>("ksize");
std::vector strides = ctx.Attr>("strides");
std::vector paddings = ctx.Attr>("paddings");
- if (ctx.Attr("globalPooling")) {
+ if (ctx.Attr("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast(input->dims()[i + 2]);
@@ -92,12 +92,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel {
ctx.Input(framework::GradVarName("Out"));
Tensor *input_grad = ctx.Output(framework::GradVarName("X"));
- std::string pooling_type = ctx.Attr("poolingType");
+ std::string pooling_type = ctx.Attr("pooling_type");
std::vector ksize = ctx.Attr>("ksize");
std::vector strides = ctx.Attr>("strides");
std::vector paddings = ctx.Attr>("paddings");
- if (ctx.Attr("globalPooling")) {
+ if (ctx.Attr("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast(input->dims()[i + 2]);
diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc
index f58aab733866973f477ca79e5e53ba58adbf3dc7..f3963b1995ef8767786f0bf230b134afc69aa99d 100644
--- a/paddle/operators/pool_op.cc
+++ b/paddle/operators/pool_op.cc
@@ -29,7 +29,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
auto in_x_dims = ctx->GetInputDim("X");
- std::string pooling_type = ctx->Attrs().Get("poolingType");
+ std::string pooling_type = ctx->Attrs().Get("pooling_type");
std::vector ksize = ctx->Attrs().Get>("ksize");
std::vector strides = ctx->Attrs().Get>("strides");
std::vector paddings = ctx->Attrs().Get>("paddings");
@@ -37,7 +37,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D tensor.");
- if (ctx->Attrs().Get("globalPooling")) {
+ if (ctx->Attrs().Get("global_pooling")) {
ksize.resize(static_cast(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
@@ -83,20 +83,20 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
"H is the height of the feature, "
"and W is the width of the feature.");
- AddAttr("poolingType",
+ AddAttr("pooling_type",
"(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.")
.InEnum({"max", "avg"});
AddAttr>("ksize",
"(vector) The pooling window "
"size(height, width) of the pooling operator. "
- "If globalPooling = true, ksize and paddings will "
+ "If global_pooling = true, ksize and paddings will "
"be ignored."); // TODO(Chengduo): Add checker.
// (Currently,
// TypedAttrChecker don't support vector type.)
- AddAttr("globalPooling",
+ AddAttr("global_pooling",
"(bool, default false) Whether to use the global pooling. "
- "If globalPooling = true, ksize and paddings will be ignored.")
+ "If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault(false);
AddAttr>("strides",
"(vector, default {1, 1}), strides(height, "
@@ -107,7 +107,7 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
"paddings",
"(vector, defalut {0,0}), paddings(height, width) of pooling "
"operator."
- "If globalPooling = true, paddings and ksize will be ignored.")
+ "If global_pooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
@@ -115,7 +115,7 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
Pool2d Operator.
The pooling2d operation calculates the output based on
-the input, poolingType and ksize, strides, paddings parameters.
+the input, pooling_type and ksize, strides, paddings parameters.
Input(X) and output(Out) are in NCHW format, where N is batch size, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
Parameters(ksize, strides, paddings) are two elements.
@@ -152,7 +152,7 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively.");
- AddAttr("poolingType",
+ AddAttr("pooling_type",
"(string) Pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.")
.InEnum({"max", "avg"});
@@ -160,13 +160,14 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"ksize",
"(vector) The pooling window size(depth, height, "
"width) of pooling operator. "
- "If globalPooling = true, ksize and paddings will "
+ "If global_pooling = true, ksize and paddings will "
"be ignored."); // TODO(Chengduo): Add checker.
// (Currently,
// TypedAttrChecker don't support vector type.)
- AddAttr("globalPooling",
- "(bool, default false) Whether to use the global pooling. "
- "If globalPooling = true, ksize and paddings wille be ignored.")
+ AddAttr(
+ "global_pooling",
+ "(bool, default false) Whether to use the global pooling. "
+ "If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault(false);
AddAttr>(
"strides",
@@ -178,7 +179,7 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"paddings",
"(vector, defalut {0,0,0}), paddings(depth, height, "
"width) of pooling operator. "
- "If globalPooling = true, ksize and paddings will be ignored.")
+ "If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
@@ -186,7 +187,7 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
Pool3d Operator.
The pooling3d operation calculates the output based on
-the input, poolingType, ksize, strides, and paddings parameters.
+the input, pooling_type, ksize, strides, and paddings parameters.
Input(X) and output(Out) are in NCDHW format, where N is batch
size, C is the number of channels, and D, H and W are the depth, height and
width of the feature, respectively. Parameters(ksize, strides, paddings)
diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h
index d9d445f6a6257b0c8a1959c64c9a878539e10cd4..4da1941ab541483e706257667b14aa5a95e0c3cc 100644
--- a/paddle/operators/pool_op.h
+++ b/paddle/operators/pool_op.h
@@ -57,11 +57,11 @@ class PoolKernel : public framework::OpKernel {
const Tensor* in_x = context.Input("X");
Tensor* out = context.Output("Out");
- std::string pooling_type = context.Attr("poolingType");
+ std::string pooling_type = context.Attr("pooling_type");
std::vector ksize = context.Attr>("ksize");
std::vector strides = context.Attr>("strides");
std::vector paddings = context.Attr>("paddings");
- if (context.Attr("globalPooling")) {
+ if (context.Attr("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast(in_x->dims()[i + 2]);
@@ -119,12 +119,12 @@ class PoolGradKernel : public framework::OpKernel {
context.Input(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output(framework::GradVarName("X"));
- std::string pooling_type = context.Attr("poolingType");
+ std::string pooling_type = context.Attr("pooling_type");
std::vector ksize = context.Attr