diff --git a/benchmark/paddle/image/vgg.py b/benchmark/paddle/image/vgg.py
index b8429975f5c83df6996e71478fe276b246e8b77b..420884ed8e1ae36a3f1772bfbe8323f3d0ea71e6 100644
--- a/benchmark/paddle/image/vgg.py
+++ b/benchmark/paddle/image/vgg.py
@@ -13,7 +13,7 @@ define_py_data_sources2(
 
 settings(
     batch_size=batch_size,
-    learning_rate=0.01 / batch_size,
+    learning_rate=0.001 / batch_size,
     learning_method=MomentumOptimizer(0.9),
     regularization=L2Regularization(0.0005 * batch_size))
 
diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake
index 8fdc382f0c1c453a01dba884a3dad216e1c3092c..b21fc43904d9aafe9f7d019dfbe5b1c0d3f9e2d6 100644
--- a/cmake/cblas.cmake
+++ b/cmake/cblas.cmake
@@ -1,17 +1,12 @@
 # Find the CBlas and lapack libraries
 #
-# It will search MKL, atlas, OpenBlas, reference-cblas in order.
+# It will search MKLML, atlas, OpenBlas, reference-cblas in order.
 #
 # If any cblas implementation found, the following variable will be set.
-#    CBLAS_PROVIDER  # one of MKL, ATLAS, OPENBLAS, REFERENCE
+#    CBLAS_PROVIDER  # one of MKLML, ATLAS, OPENBLAS, REFERENCE
 #    CBLAS_INC_DIR   # the include directory for cblas.
 #    CBLAS_LIBS      # a list of libraries should be linked by paddle.
 #                    # Each library should be full path to object file.
-#
-# User should set one of MKL_ROOT, ATLAS_ROOT, OPENBLAS_ROOT, REFERENCE_CBLAS_ROOT
-# during cmake. If none of them set, it will try to find cblas implementation in
-# system paths.
-#
 
 set(CBLAS_FOUND OFF)
 
@@ -30,44 +25,6 @@ if(WITH_MKLML AND MKLML_INC_DIR AND MKLML_LIB)
   return()
 endif()
 
-## Then find MKL.
-set(INTEL_MKL_ROOT "/opt/intel/mkl" CACHE PATH "Folder contains intel mkl libs")
-set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder contains env MKL")
-
-set(MKL_INCLUDE_SEARCH_PATHS
-  ${MKL_ROOT}/include
-  ${INTEL_MKL_ROOT}/include)
-set(MKL_LIB_SEARCH_PATHS
-  ${MKL_ROOT}/lib
-  ${MKL_ROOT}/lib/intel64
-  ${INTEL_MKL_ROOT}/lib
-  ${INTEL_MKL_ROOT}/lib/intel64)
-
-find_path(MKL_INC_DIR mkl.h PATHS
-  ${MKL_INCLUDE_SEARCH_PATHS})
-find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS
-  ${MKL_INCLUDE_SEARCH_PATHS})
-find_library(MKL_CORE_LIB NAMES mkl_core PATHS
-  ${MKL_LIB_SEARCH_PATHS})
-find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS
-  ${MKL_LIB_SEARCH_PATHS})
-find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS
-  ${MKL_LIB_SEARCH_PATHS})
-
-if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64)
-  set(CBLAS_FOUND ON)
-  set(CBLAS_PROVIDER MKL)
-  set(CBLAS_INC_DIR ${MKL_INC_DIR} ${MKL_LAPACK_INC_DIR})
-  set(CBLAS_LIBRARIES ${MKL_INTEL_LP64} ${MKL_SEQUENTIAL_LIB} ${MKL_CORE_LIB})
-
-  add_definitions(-DPADDLE_USE_MKL)
-  add_definitions(-DLAPACK_FOUND)
-
-  message(STATUS "Found MKL (include: ${MKL_INC_DIR}, library: ${CBLAS_LIBRARIES})")
-  message(STATUS "Found lapack in MKL (include: ${MKL_LAPACK_INC_DIR})")
-  return()
-endif()
-
 ## Then find atlas.
 set(ATLAS_ROOT $ENV{ATLAS_ROOT} CACHE PATH "Folder contains Atlas")
 set(ATLAS_INCLUDE_SEARCH_PATHS
diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index 9686df00219001769d074ee815d9cc8db0258496..5a06825beb73e85d8a55b7b578b187bee2c4340c 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -46,16 +46,20 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
     MESSAGE(STATUS "Build MKLDNN with ${MKLDNN_MKLROOT}")
 ENDIF()
 
+SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} -Wno-error=strict-overflow")
+SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} -Wno-error=strict-overflow")
 ExternalProject_Add(
     ${MKLDNN_PROJECT}
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS             ${MKLDNN_DEPENDS}
     GIT_REPOSITORY      "https://github.com/01org/mkl-dnn.git"
-    GIT_TAG             "v0.10"
+    GIT_TAG             "v0.11"
     PREFIX              ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND      ""
     CMAKE_ARGS          -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
     CMAKE_ARGS          -DMKLROOT=${MKLDNN_MKLROOT}
+    CMAKE_ARGS          -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
+    CMAKE_ARGS          -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
     CMAKE_CACHE_ARGS    -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
                         -DMKLROOT:PATH=${MKLDNN_MKLROOT}
 )
diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake
index 74f3279831357c21038df133df0f5a432a6dfd20..20dbc32a738d982df2d3f035206279c82c8de264 100644
--- a/cmake/external/mklml.cmake
+++ b/cmake/external/mklml.cmake
@@ -27,8 +27,8 @@ ENDIF()
 INCLUDE(ExternalProject)
 
 SET(MKLML_PROJECT       "extern_mklml")
-SET(MKLML_VER           "mklml_lnx_2018.0.20170720")
-SET(MKLML_URL           "https://github.com/01org/mkl-dnn/releases/download/v0.10/${MKLML_VER}.tgz")
+SET(MKLML_VER           "mklml_lnx_2018.0.1.20171007")
+SET(MKLML_URL           "https://github.com/01org/mkl-dnn/releases/download/v0.11/${MKLML_VER}.tgz")
 SET(MKLML_SOURCE_DIR    "${THIRD_PARTY_PATH}/mklml")
 SET(MKLML_DOWNLOAD_DIR  "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
 SET(MKLML_DST_DIR       "mklml")
diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake
index 3f86e456cfbe55fe47e5b18e755e34829ebe9930..42ffd6cf347ea54ce059519770629761ae0256ee 100644
--- a/cmake/external/openblas.cmake
+++ b/cmake/external/openblas.cmake
@@ -115,7 +115,7 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
 # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
 SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
 FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
-IF(${CBLAS_PROVIDER} MATCHES MKL)
+IF(${CBLAS_PROVIDER} EQUAL MKLML)
     ADD_LIBRARY(cblas SHARED ${dummyfile})
 ELSE()
     ADD_LIBRARY(cblas STATIC ${dummyfile})
diff --git a/doc/api/v2/data.rst b/doc/api/v2/data.rst
index fef87c4fbdb452771ecdb361c6eeae5b32bcee14..b56c7332cc284649c7e04328e51a7faa78593a39 100644
--- a/doc/api/v2/data.rst
+++ b/doc/api/v2/data.rst
@@ -2,112 +2,9 @@
 Data Reader Interface and DataSets
 ==================================
 
+..  toctree::
+    :maxdepth: 1
 
-DataTypes
-=========
-
-..  automodule:: paddle.v2.data_type
-    :members:
-    :noindex:
-
-DataFeeder
-==========
-
-..  automodule:: paddle.v2.data_feeder
-    :members:
-    :noindex:
-
-Reader
-======
-
-..  automodule:: paddle.v2.reader
-    :members:
-    :noindex:
-
-..  automodule:: paddle.v2.reader.creator
-    :members:
-    :noindex:
-
-minibatch
-=========
-
-..  automodule:: paddle.v2.minibatch
-    :members:
-    :noindex:
-
-Dataset
-=======
-
-..  automodule:: paddle.v2.dataset
-    :members:
-    :noindex:
-
-mnist
-+++++
-
-..  automodule:: paddle.v2.dataset.mnist
-    :members:
-    :noindex:
-
-cifar
-+++++
-
-..  automodule:: paddle.v2.dataset.cifar
-    :members:
-    :noindex:
-
-conll05
-+++++++
-
-..  automodule:: paddle.v2.dataset.conll05
-    :members: get_dict,get_embedding,test
-    :noindex:
-
-imdb
-++++
-
-..  automodule:: paddle.v2.dataset.imdb
-    :members:
-    :noindex:
-
-imikolov
-++++++++
-
-..  automodule:: paddle.v2.dataset.imikolov
-    :members:
-    :noindex:
-
-movielens
-+++++++++
-
-..  automodule:: paddle.v2.dataset.movielens
-    :members:
-    :noindex:
-
-..  autoclass:: paddle.v2.dataset.movielens.MovieInfo
-    :noindex:
-    
-..  autoclass:: paddle.v2.dataset.movielens.UserInfo
-    :noindex:
-
-sentiment
-+++++++++
-
-..  automodule:: paddle.v2.dataset.sentiment
-    :members:
-    :noindex:
-
-uci_housing
-+++++++++++
-
-..  automodule:: paddle.v2.dataset.uci_housing
-    :members:
-    :noindex:
-
-wmt14
-+++++
-
-..  automodule:: paddle.v2.dataset.wmt14
-    :members:
-    :noindex:
-
+    data/data_reader.rst
+    data/image.rst
+    data/dataset.rst
diff --git a/doc/api/v2/data/data_reader.rst b/doc/api/v2/data/data_reader.rst
new file mode 100644
index 0000000000000000000000000000000000000000..2ccfec9c284877a7576e9751526b169a4ac78d8e
--- /dev/null
+++ b/doc/api/v2/data/data_reader.rst
@@ -0,0 +1,36 @@
+=====================
+Data Reader Interface
+=====================
+
+
+DataTypes
+=========
+
+..  automodule:: paddle.v2.data_type
+    :members:
+    :noindex:
+
+DataFeeder
+==========
+
+..  automodule:: paddle.v2.data_feeder
+    :members:
+    :noindex:
+
+Reader
+======
+
+..  automodule:: paddle.v2.reader
+    :members:
+    :noindex:
+
+..  automodule:: paddle.v2.reader.creator
+    :members:
+    :noindex:
+
+minibatch
+=========
+
+..  automodule:: paddle.v2.minibatch
+    :members:
+    :noindex:
diff --git a/doc/api/v2/data/dataset.rst b/doc/api/v2/data/dataset.rst
new file mode 100644
index 0000000000000000000000000000000000000000..6a8ecc5bb1d855e0ded3719943ab3adb810de365
--- /dev/null
+++ b/doc/api/v2/data/dataset.rst
@@ -0,0 +1,75 @@
+Dataset
+=======
+
+..  automodule:: paddle.v2.dataset
+    :members:
+    :noindex:
+
+mnist
++++++
+
+..  automodule:: paddle.v2.dataset.mnist
+    :members:
+    :noindex:
+
+cifar
++++++
+
+..  automodule:: paddle.v2.dataset.cifar
+    :members:
+    :noindex:
+
+conll05
++++++++
+
+..  automodule:: paddle.v2.dataset.conll05
+    :members: get_dict,get_embedding,test
+    :noindex:
+
+imdb
+++++
+
+..  automodule:: paddle.v2.dataset.imdb
+    :members:
+    :noindex:
+
+imikolov
+++++++++
+
+..  automodule:: paddle.v2.dataset.imikolov
+    :members:
+    :noindex:
+
+movielens
++++++++++
+
+..  automodule:: paddle.v2.dataset.movielens
+    :members:
+    :noindex:
+
+..  autoclass:: paddle.v2.dataset.movielens.MovieInfo
+    :noindex:
+    
+..  autoclass:: paddle.v2.dataset.movielens.UserInfo
+    :noindex:
+
+sentiment
++++++++++
+
+..  automodule:: paddle.v2.dataset.sentiment
+    :members:
+    :noindex:
+
+uci_housing
++++++++++++
+
+..  automodule:: paddle.v2.dataset.uci_housing
+    :members:
+    :noindex:
+
+wmt14
++++++
+
+..  automodule:: paddle.v2.dataset.wmt14
+    :members:
+    :noindex:
diff --git a/doc/api/v2/data/image.rst b/doc/api/v2/data/image.rst
new file mode 100644
index 0000000000000000000000000000000000000000..97651ffa6be56cf3ecaca2caca38a353fa5c1f49
--- /dev/null
+++ b/doc/api/v2/data/image.rst
@@ -0,0 +1,5 @@
+Image Interface
+===============
+
+..  automodule:: paddle.v2.image
+    :members:
diff --git a/doc/design/ops/images/LOD-and-shape-changes-during-decoding.jpg b/doc/design/ops/images/LOD-and-shape-changes-during-decoding.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b0d90f7b9d8184b314b0ee4e521f53eb5f1b455
Binary files /dev/null and b/doc/design/ops/images/LOD-and-shape-changes-during-decoding.jpg differ
diff --git a/doc/design/ops/sequence_decoder.md b/doc/design/ops/sequence_decoder.md
new file mode 100644
index 0000000000000000000000000000000000000000..9007aae7a8355ed06c6720a921351f81b859c1fe
--- /dev/null
+++ b/doc/design/ops/sequence_decoder.md
@@ -0,0 +1,245 @@
+# Design: Sequence Decoder Generating LoDTensors
+In tasks such as machine translation and image to text, 
+a [sequence decoder](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.md) is necessary to generate sequences.
+
+This documentation describes how to implement the sequence decoder as an operator.
+
+## Beam Search based Decoder
+The [beam search algorithm](https://en.wikipedia.org/wiki/Beam_search) is necessary when generating sequences, 
+it is a heuristic search algorithm that explores the paths by expanding the most promising node in a limited set.
+
+In the old version of PaddlePaddle, a C++ class `RecurrentGradientMachine` implements the general sequence decoder based on beam search, 
+due to the complexity, the implementation relays on a lot of special data structures, 
+quite trivial and hard to be customized by users.
+
+There are a lot of heuristic tricks in the sequence generation tasks, 
+so the flexibility of sequence decoder is very important to users.
+
+During PaddlePaddle's refactoring work,
+some new concept is proposed such as [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) and [TensorArray](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md) that can better support sequence usage,
+and they can help to make the implementation of beam search based sequence decoder **more transparent and modular** .
+
+For example, the RNN sates, candidates IDs and probabilities of beam search can be represented as `LoDTensors`;
+the selected candidate's IDs in each time step can be stored in a `TensorArray`, and `Packed` to the sentences translated.
+
+## Changing LoD's absolute offset to relative offsets
+The current `LoDTensor` is designed to store levels of variable-length sequences,
+it stores several arrays of integers each represents a level.
+
+The integers in each level represents the begin and end (not inclusive) offset of a sequence **in the underlying tensor**, 
+let's call this format the **absolute-offset LoD** for clear.
+
+The relative-offset LoD can fast retrieve any sequence but fails to represent empty sequences, for example, a two-level LoD is as follows
+```python
+[[0, 3, 9]
+ [0, 2, 3, 3, 3, 9]]
+```
+The first level tells that there are two sequences:
+- the first's offset is `[0, 3)`
+- the second's offset is `[3, 9)`
+
+while on the second level, there are several empty sequences that both begin and end at `3`.
+It is impossible to tell how many empty second-level sequences exist in the first-level sequences.
+
+There are many scenarios that relay on empty sequence representation,
+such as machine translation or image to text, one instance has no translations or the empty candidate set for a prefix.
+
+So let's introduce another format of LoD, 
+it stores **the offsets of the lower level sequences** and is called **relative-offset** LoD.
+
+For example, to represent the same sequences of the above data
+
+```python
+[[0, 3, 6]
+ [0, 2, 3, 3, 3, 9]]
+```
+
+the first level represents that there are two sequences, 
+their offsets in the second-level LoD is `[0, 3)` and `[3, 5)`.
+
+The second level is the same with the relative offset example because the lower level is a tensor.
+It is easy to find out the second sequence in the first-level LoD has two empty sequences.
+
+The following demos are based on relative-offset LoD.
+
+## Usage in a simple machine translation model
+Let's start from a simple machine translation model that is simplified from [machine translation chapter](https://github.com/PaddlePaddle/book/tree/develop/08.machine_translation) to draw a simple blueprint of what a sequence decoder can do and how to use it.
+
+The model has an encoder that learns the semantic vector from a sequence,
+and a decoder which uses the sequence decoder to generate new sentences.
+
+**Encoder**
+```python
+import paddle as pd
+
+dict_size = 8000
+source_dict_size = dict_size
+target_dict_size = dict_size
+word_vector_dim = 128
+encoder_dim = 128
+decoder_dim = 128
+beam_size = 5
+max_length = 120
+
+# encoder
+src_word_id = pd.data(
+    name='source_language_word',
+    type=pd.data.integer_value_sequence(source_dict_dim))
+src_embedding = pd.embedding(size=source_dict_size, size=word_vector_dim)
+
+src_word_vec = pd.lookup(src_embedding, src_word_id)
+
+encoder_out_seq = pd.gru(input=src_word_vec, size=encoder_dim)
+
+encoder_ctx = pd.last_seq(encoder_out_seq)
+# encoder_ctx_proj is the learned semantic vector
+encoder_ctx_proj = pd.fc(
+    encoder_ctx, size=decoder_dim, act=pd.activation.Tanh(), bias=None)
+```
+
+**Decoder**
+
+```python
+def generate():
+    decoder = pd.while_loop()
+    with decoder.step():
+        decoder_mem = decoder.memory(init=encoder_ctx)  # mark the memory
+        generated_ids = decoder.memory() # TODO init to batch_size s
+        generated_scores = decoder.memory() # TODO init to batch_size 1s or 0s
+
+        target_word = pd.lookup(trg_embedding, gendrated_ids)
+        # expand encoder_ctx's batch to fit target_word's lod
+        # for example
+        # decoder_mem.lod is
+        # [[0 1 3],
+        #  [0 1 3 6]]
+        # its tensor content is [a1 a2 a3 a4 a5]
+        # which means there are 2 sentences to translate
+        #   - the first sentence has 1 translation prefixes, the offsets are [0, 1)
+        #   - the second sentence has 2 translation prefixes, the offsets are [1, 3) and [3, 6)
+        # the target_word.lod is 
+        # [[0, 1, 6]
+        #  [0, 2, 4, 7, 9 12]]
+        # which means 2 sentences to translate, each has 1 and 5 prefixes
+        # the first prefix has 2 candidates
+        # the following has 2, 3, 2, 3 candidates
+        # the encoder_ctx_expanded's content will be
+        # [a1 a1 a2 a2 a3 a3 a3 a4 a4 a5 a5 a5]
+        encoder_ctx_expanded = pd.lod_expand(encoder_ctx, target_word)
+        decoder_input = pd.fc(
+            act=pd.activation.Linear(),
+            input=[target_word, encoder_ctx],
+            size=3 * decoder_dim)
+        gru_out, cur_mem = pd.gru_step(
+            decoder_input, mem=decoder_mem, size=decoder_dim)
+        scores = pd.fc(
+            gru_out,
+            size=trg_dic_size,
+            bias=None,
+            act=pd.activation.Softmax())
+        # K is an config
+        topk_scores, topk_ids = pd.top_k(scores, K)
+        topk_generated_scores = pd.add_scalar(topk_scores, generated_scores)
+
+        selected_ids, selected_generation_scores = decoder.beam_search(
+            topk_ids, topk_generated_scores)
+
+        # update the states
+        decoder_mem.update(cur_mem)  # tells how to update state
+        generated_ids.update(selected_ids)
+        generated_scores.update(selected_generation_scores)
+
+        decoder.output(selected_ids)
+        decoder.output(selected_generation_scores)
+
+translation_ids, translation_scores = decoder()
+```
+The `decoder.beam_search` is a operator that given the candidates and the scores of translations including the candidates,
+return the result of the beam search algorithm.
+
+In this way, users can customize anything on the inputs or outputs of beam search, for example, two ways to prune some translation prefixes
+
+1. meke the correspondind elements in `topk_generated_scores` zero or some small values, beam_search will discard this candidate.
+2. remove some specific candidate in `selected_ids`
+3. get the final `translation_ids`, remove the translation sequence in it.
+
+The implementation of sequence decoder can reuse the C++ class [RNNAlgorithm](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/paddle/operators/dynamic_recurrent_op.h#L30),
+so the python syntax is quite similar to a [RNN](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/doc/design/block.md#blocks-with-for-and-rnnop).
+
+Both of them are two-level `LoDTensors`
+
+- the first level represents `batch_size` of (source) sentences;
+- the second level represents the candidate ID sets for translation prefix.
+
+for example, 3 source sentences to translate, and has 2, 3, 1 candidates.
+
+Unlike an RNN, in sequence decoder, the previous state and the current state have different LoD and shape,
+a `lod_expand` operator is used to expand the LoD of the previous state to fit the current state.
+
+For example, the previous state
+
+* LoD is `[0, 1, 3][0, 2, 5, 6]`
+* content of tensor is `a1 a2 b1 b2 b3 c1`
+
+the current state stored in `encoder_ctx_expanded`
+
+* LoD is `[0, 2, 7][0 3 5 8 9 11 11]`
+* the content is 
+  - a1 a1 a1 (a1 has 3 candidates, so the state should be copied 3 times for each candidates)
+  - a2 a2
+  - b1 b1 b1
+  - b2
+  - b3 b3
+  - None (c1 has 0 candidates, so c1 is dropped)
+
+Benefit from the relative offset LoD, empty candidate set can be represented naturally.
+
+the status in each time step can be stored in `TensorArray`, and `Pack`ed to a final LoDTensor, the corresponding syntax is 
+
+```python
+decoder.output(selected_ids)
+decoder.output(selected_generation_scores)
+```
+
+the `selected_ids` is the candidate ids for the prefixes, 
+it will be `Packed` by `TensorArray` to a two-level `LoDTensor`,
+the first level represents the source sequences,
+the second level represents generated sequences.
+
+Pack the `selected_scores` will get a `LoDTensor` that stores scores of each candidate of translations.
+
+Pack the `selected_generation_scores` will get a `LoDTensor`, and each tail is the probability of the translation.
+
+## LoD and shape changes during decoding
+
+   +
+
+
+According the image above, the only phrase to change LoD is beam search.
+
+## Beam search design
+The beam search algorthm will be implemented as one method of the sequence decoder, it has 3 inputs
+
+1. `topk_ids`, top K candidate ids for each prefix.
+2. `topk_scores`, the corresponding scores for `topk_ids`
+3. `generated_scores`, the score of the prefixes.
+
+All of the are LoDTensors, so that the sequence affilication is clear.
+Beam search will keep a beam for each prefix and select a smaller candidate set for each prefix.
+
+It will return three variables
+
+1. `selected_ids`, the final candidate beam search function selected for the next step.
+2. `selected_scores`, the scores for the candidates.
+3. `generated_scores`, the updated scores for each prefixes (with the new candidates appended).
+
+## Introducing the LoD-based `Pack` and `Unpack` methods in `TensorArray`
+The `selected_ids`, `selected_scores` and `generated_scores` are LoDTensors,
+and they exist in each time step,
+so it is natural to store them in arrays.
+
+Currently, PaddlePaddle has a module called `TensorArray` which can store an array of tensors,
+the results of beam search are better to store in a `TensorArray`.
+
+The `Pack` and `UnPack` in `TensorArray` are used to package tensors in the array to a `LoDTensor` or split the `LoDTensor` to an array of tensors. 
+It needs some extensions to support pack or unpack an array of `LoDTensors`.
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 4e8d630c2634682ff63b38182108eadebb5c7ff9..d485cdf6109274377ad0057223bdd8401e964aa7 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -21,7 +21,7 @@
 #include "paddle/framework/var_desc.h"
 #include "paddle/operators/net_op.h"
 
-USE_OP(fill_constant);
+USE_NO_KERNEL_OP(fill_constant);
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index c5ae7b185460c8b0d68ba38bb9db9bd3d3fb14ea..3ec88d7a72c3339bf5e7d0ca3957a3f608f039b7 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -34,6 +34,21 @@ inline DataType ToDataType(std::type_index type) {
   }
 }
 
+inline std::type_index ToTypeIndex(DataType type) {
+  switch (type) {
+    case DataType::FP32:
+      return typeid(float);
+    case DataType::FP64:
+      return typeid(double);
+    case DataType::INT32:
+      return typeid(int);
+    case DataType::INT64:
+      return typeid(int64_t);
+    default:
+      PADDLE_THROW("Not support type %d", type);
+  }
+}
+
 template 
 inline void VisitDataType(DataType type, Visitor visitor) {
   switch (type) {
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 10c785e04c4fa2192f9c95513009cf7d8c123868..53b899a23997b71e723a298ec360a4e018d89878 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -79,6 +79,13 @@ DDim make_ddim(const std::vector& dims) {
   return result;
 }
 
+DDim make_ddim(const std::vector& dims) {
+  std::vector res(dims.size());
+  std::transform(dims.begin(), dims.end(), res.begin(),
+                 [](int d) { return static_cast(d); });
+  return make_ddim(res);
+}
+
 /// @cond HIDDEN
 // XXX For some reason, putting this in an anonymous namespace causes errors
 class DynamicMutableIndexer : public boost::static_visitor {
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index aa773868ab4b68acbc46dfa2cd2569d8b8b7789d..4ca5e49566b7ec006eba80f3f9808bacb1ff2615 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -81,6 +81,8 @@ struct DDim {
  */
 DDim make_ddim(const std::vector& dims);
 
+DDim make_ddim(const std::vector& dims);
+
 /**
  * \brief Make a DDim from an initializer list
  *
diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index 4fd72d64a90ae6f16dd1499ceb7fba6e40fe4cea..9b2779b42cad324253dadf27dbff20fd8e8c8e16 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -45,6 +45,7 @@ if(WITH_GPU)
     add_simple_unittest(BlockExpandOpTest)
     add_simple_unittest(CropOpTest)
     add_simple_unittest(SwitchOpTest)
+    add_simple_unittest(ScaleSubRegionOpTest)
 endif()
 
 add_simple_unittest(Im2ColTest)
diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h
index ba446bf92da264fafa1fb47a2c30da9cb13176ce..370940532ef40335be54a3e6467de0409e923ec4 100644
--- a/paddle/function/FunctionTest.h
+++ b/paddle/function/FunctionTest.h
@@ -110,6 +110,7 @@ public:
         function2_(FunctionBase::funcRegistrar_.createByType(name2)) {
     function1_->init(config);
     function2_->init(config);
+    initArgsCallback_ = nullptr;
   }
 
   ~Compare2Function() {}
@@ -170,6 +171,10 @@ public:
                                       *seq2_));
   }
 
+  void registerInitCallback(std::function callback) {
+    initArgsCallback_ = callback;
+  }
+
   // output need only contains shape, do not contains data.
   void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
     size_t size =
@@ -340,6 +345,10 @@ protected:
         initArg(*func1Inputs_[i]);
       }
 
+      if (initArgsCallback_ != nullptr) {
+        initArgsCallback_(*func1Inputs_[i], i);
+      }
+
       copyArg_(*func1Inputs_[i], *func2Inputs_[i]);
     }
   }
@@ -386,6 +395,7 @@ protected:
   std::shared_ptr seq1_;
   std::shared_ptr seq2_;
   test::CopyArgument copyArg_;
+  std::function initArgsCallback_;
 };
 
 class CpuGpuFuncCompare
diff --git a/paddle/function/ScaleSubRegionOp.cpp b/paddle/function/ScaleSubRegionOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a080505d7df83a6c0a9d88fbcb7863fc0e1f7b21
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOp.cpp
@@ -0,0 +1,155 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionOp.h"
+#include "paddle/function/TensorShape.h"
+
+namespace paddle {
+
+template <>
+void ScaleSubRegion(real* outputs,
+                                     const real* inputs,
+                                     const real* indices,
+                                     const TensorShape shape,
+                                     const FuncConfig& conf) {
+  real value = conf.get("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  memcpy(outputs, inputs, number * channel * height * width * sizeof(real));
+
+  for (int n = 0; n < number; ++n) {
+    // indices start from 1
+    int offset = n * 6;
+    for (int c = indices[offset] - 1; c < indices[offset + 1]; ++c) {
+      for (int h = indices[offset + 2] - 1; h < indices[offset + 3]; ++h) {
+        for (int w = indices[offset + 4] - 1; w < indices[offset + 5]; ++w) {
+          int idx = ((n * channel + c) * height + h) * width + w;
+          outputs[idx] *= value;
+        }
+      }
+    }
+  }
+}
+
+template <>
+void ScaleSubRegionGrad(const real* inGrad,
+                                         real* outGrad,
+                                         const real* indices,
+                                         const TensorShape shape,
+                                         const FuncConfig& conf) {
+  real value = conf.get("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  for (int n = 0; n < number; ++n) {
+    for (int c = 0; c < channel; ++c) {
+      for (int h = 0; h < height; ++h) {
+        for (int w = 0; w < width; ++w) {
+          int idx = ((n * channel + c) * height + h) * width + w;
+          int offset = n * 6;
+          if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+              h >= (indices[offset + 2] - 1) &&
+              h <= (indices[offset + 3] - 1) &&
+              w >= (indices[offset + 4] - 1) &&
+              w <= (indices[offset + 5] - 1)) {
+            outGrad[idx] += inGrad[idx] * value;
+          } else {
+            outGrad[idx] += inGrad[idx];
+          }
+        }
+      }
+    }
+  }
+}
+
+/**
+ * \brief For each instance, ScaleSubRegion can be used to multiply a value to
+ *        a specified sub continuous region. By providing start index and end
+ *        index for C/H/W, you can specify the location and shape of the region.
+ *
+ * Argument in this Function:
+ * \param inputs    A 4-D tensor with shape [N, C, H, W], only one input.
+ * \param indices   A 2-D tensor with shape [N, 6], indicates the sub region.
+ * \param outputs   A 4-D tensor with same shape as inputs, output value.
+ */
+template 
+class ScaleSubRegionFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override { conf_ = config; }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(2UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+
+    TensorShape shape = inputs[0].shape();
+
+    ScaleSubRegion(outputs[0].data(),
+                           inputs[0].data(),
+                           inputs[1].data(),
+                           shape,
+                           conf_);
+  }
+
+private:
+  FuncConfig conf_;
+};
+
+/**
+ * \brief The backward propagation of ScaleSubRegion Function.
+ *
+ * Argument in this Function:
+ * \param inputs  A 4-D tensor with shape [N, C, H, W], output gradient.
+ * \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
+ * \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
+ */
+
+template 
+class ScaleSubRegionGradFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override { conf_ = config; }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(2UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
+    TensorShape shape = inputs[0].shape();
+
+    ScaleSubRegionGrad(inputs[0].data(),
+                               outputs[0].data(),
+                               inputs[1].data(),
+                               shape,
+                               conf_);
+  }
+
+private:
+  FuncConfig conf_;
+};
+
+REGISTER_TYPED_FUNC(ScaleSubRegion, CPU, ScaleSubRegionFunc);
+REGISTER_TYPED_FUNC(ScaleSubRegionGrad, CPU, ScaleSubRegionGradFunc);
+#ifdef PADDLE_WITH_CUDA
+REGISTER_TYPED_FUNC(ScaleSubRegion, GPU, ScaleSubRegionFunc);
+REGISTER_TYPED_FUNC(ScaleSubRegionGrad, GPU, ScaleSubRegionGradFunc);
+#endif
+
+}  // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOp.h b/paddle/function/ScaleSubRegionOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..0480c8577f3fbf3bc9e94b635df96a31b103e9e3
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOp.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Function to multiply a value to values in specified sub continuous
+ *        region. Indices must be provided to indcate the location and shape of
+ *        the region and the multiplied value is passed by configure variable.
+ *
+ *
+ * \param[out] outputs  Output value.
+ * \param[in]  inputs   Input data which contains NCHW information.
+ * \param[in]  indices  Indices data to indcate the sub region.
+ * \param[in]  shape    Tensor shape of input value.
+ * \param[in]  conf     Configure variable which contains the multiplied value.
+ */
+template 
+void ScaleSubRegion(real* outputs,
+                    const real* inputs,
+                    const real* indices,
+                    const TensorShape shape,
+                    const FuncConfig& conf);
+
+/**
+ * \brief Backward propagation function of ScaleSubRegion.
+ *
+ * \param[out] inGrad   Gradients of previous layer.
+ * \param[in]  outGrad  Output gradient.
+ * \param[in]  indices  Indices data.
+ * \param[in]  shape    The Shape of input tensor.
+ * \param[in]  conf     Configure variable.
+ */
+template 
+void ScaleSubRegionGrad(const real* inGrad,
+                        real* outGrad,
+                        const real* indices,
+                        const TensorShape shape,
+                        const FuncConfig& conf);
+}  // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOpGpu.cu b/paddle/function/ScaleSubRegionOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8aae2e44c3fdc8b516e66ecfd2e04f466a17dde9
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOpGpu.cu
@@ -0,0 +1,116 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionOp.h"
+#include "hl_base.h"
+
+namespace paddle {
+
+__global__ void KeScaleSubRegion(real* outputs,
+                                 const real* inputs,
+                                 const real* indices,
+                                 real value,
+                                 int channel,
+                                 int height,
+                                 int width,
+                                 int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int c = (idx / width / height) % channel;
+    const int n = idx / width / height / channel;
+
+    const int offset = n * 6;
+    if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+        h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
+        w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
+      outputs[idx] = inputs[idx] * value;
+    } else {
+      outputs[idx] = inputs[idx];
+    }
+  }
+}
+
+template <>
+void ScaleSubRegion(real* outputs,
+                                     const real* inputs,
+                                     const real* indices,
+                                     const TensorShape shape,
+                                     const FuncConfig& conf) {
+  real value = conf.get("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  size_t nth = number * channel * height * width;
+  int blockSize = 1024;
+  int gridSize = (nth + blockSize - 1) / blockSize;
+
+  KeScaleSubRegion<<>>(
+      outputs, inputs, indices, value, channel, height, width, nth);
+  CHECK_SYNC("ScaleSubRegion");
+}
+
+__global__ void KeScaleSubRegionDiff(const real* inGrad,
+                                     real* outGrad,
+                                     const real* indices,
+                                     real value,
+                                     int channel,
+                                     int height,
+                                     int width,
+                                     int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int c = (idx / width / height) % channel;
+    const int n = idx / width / height / channel;
+
+    const int offset = n * 6;
+    if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+        h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
+        w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
+      outGrad[idx] += inGrad[idx] * value;
+    } else {
+      outGrad[idx] += inGrad[idx];
+    }
+  }
+}
+
+template <>
+void ScaleSubRegionGrad(const real* inGrad,
+                                         real* outGrad,
+                                         const real* indices,
+                                         const TensorShape shape,
+                                         const FuncConfig& conf) {
+  real value = conf.get("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  size_t nth = number * channel * height * width;
+  int blockSize = 1024;
+  int gridSize = (nth + blockSize - 1) / blockSize;
+
+  KeScaleSubRegionDiff<<>>(
+      inGrad, outGrad, indices, value, channel, height, width, nth);
+  CHECK_SYNC("ScaleSubRegionGrad");
+}
+
+}  // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOpTest.cpp b/paddle/function/ScaleSubRegionOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43331f258dddaa43cbc8cc77519e299de7e98290
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOpTest.cpp
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include 
+#include "FunctionTest.h"
+
+namespace paddle {
+
+TEST(ScaleSubRegion, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {5, 32}) {
+      for (size_t imgSizeH : {5, 33}) {
+        for (size_t imgSizeW : {5, 32}) {
+          for (real value : {-0.5, 0.0, 0.5}) {
+            for (bool firstHalf : {false, true}) {
+              VLOG(3) << " numSamples=" << numSamples
+                      << " channels=" << channels << " imgSizeH=" << imgSizeH
+                      << " imgSizeW=" << imgSizeW;
+
+              for (bool testGrad : {false, true}) {
+                CpuGpuFuncCompare compare(
+                    testGrad ? "ScaleSubRegionGrad" : "ScaleSubRegion",
+                    FuncConfig().set("value", value));
+
+                TensorShape shape{numSamples, channels, imgSizeH, imgSizeW};
+                TensorShape indicesShape{numSamples, 6};
+
+                compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape));
+                compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, indicesShape));
+
+                compare.registerInitCallback([=](BufferArg& arg, size_t index) {
+                  if (index == 1) {
+                    real* data = (real*)arg.data();
+
+                    for (size_t i = 0; i < numSamples; ++i) {
+                      size_t offset = i * 6;
+                      data[offset] = firstHalf ? 1 : channels / 2;
+                      data[offset + 1] = firstHalf ? channels / 2 : channels;
+                      data[offset + 2] = firstHalf ? 1 : imgSizeH / 2;
+                      data[offset + 3] = firstHalf ? imgSizeH / 2 : imgSizeH;
+                      data[offset + 4] = firstHalf ? 1 : imgSizeW / 2;
+                      data[offset + 5] = firstHalf ? imgSizeW / 2 : imgSizeW;
+                    }
+                  }
+                });
+
+                compare.addOutputs(
+                    BufferArg(
+                        VALUE_TYPE_FLOAT, shape, testGrad ? ADD_TO : ASSIGN_TO),
+                    testGrad ? ADD_TO : ASSIGN_TO);
+                compare.run();
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp
index 8eb700723f2cf7dda969739bb5e3d48358d278a0..6ffe4fbec643e50d27924a989875454d307f5b9b 100644
--- a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp
@@ -62,16 +62,14 @@ void MKLDNNAddtoLayer::resetFwd(std::vector& pipeline,
                                 MKLDNNMatrixPtr& wgt,
                                 MKLDNNMatrixPtr& bias,
                                 MKLDNNMatrixPtr& out) {
-  if (biases_) {
-    LOG(FATAL) << "not implemented yet";
-  }
-  resetFwdBuffers(inVals_, out);
+  resetFwdBuffers(inVals_, bias, out);
   in = inVals_[0];
 
   std::shared_ptr fwdPD;
-  resetFwdPD(fwdPD, inVals_, out);
+  std::shared_ptr biasPD;
+  resetFwdPD(fwdPD, biasPD, inVals_, bias, out);
 
-  resetFwdPipeline(pipeline, fwdPD, inVals_, out);
+  resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, bias, out);
 }
 
 void MKLDNNAddtoLayer::resetBwd(std::vector& pipeline,
@@ -79,7 +77,7 @@ void MKLDNNAddtoLayer::resetBwd(std::vector& pipeline,
                                 MKLDNNMatrixPtr& wgt,
                                 MKLDNNMatrixPtr& bias,
                                 MKLDNNMatrixPtr& out) {
-  resetBwdBuffers(inGrads_, out);
+  resetBwdBuffers(inGrads_, bias, out);
   in = inGrads_[0];
 
   // backward only need share output grad to input grad
@@ -89,6 +87,20 @@ void MKLDNNAddtoLayer::resetBwd(std::vector& pipeline,
       inputLayers_[i]->getOutputGrad()->setData(inGrads_[i]->getData());
     }
   }
+
+  // backward bias
+  bwdBias_ = nullptr;
+  if (bias) {
+    std::vector scales(bs_, 1.0);
+    std::vector srcPDs(bs_, bias->getPrimitiveDesc());
+    auto biasPD = sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs);
+    std::vector srcs;
+    for (size_t i = 0; i < grads_.size(); ++i) {
+      srcs.push_back(*(grads_[i]));
+    }
+    bwdBias_.reset(new sum(biasPD, srcs, *bias));
+    pipeline.push_back(*bwdBias_);
+  }
 }
 
 void MKLDNNAddtoLayer::updateWeights(const UpdateCallback& callback) {
@@ -97,7 +109,25 @@ void MKLDNNAddtoLayer::updateWeights(const UpdateCallback& callback) {
   }
 }
 
+void MKLDNNAddtoLayer::prepareBias(MKLDNNMatrixPtr& bias,
+                                   const MatrixPtr& biasMat,
+                                   const MKLDNNMatrixPtr& out,
+                                   std::vector& outs) {
+  auto pd = MKLDNNMatrix::createPrimitiveDesc(
+      {(int)layerSize_}, memory::format::x, engine_);
+  bias = MKLDNNMatrix::create(pd, biasMat);
+  outs.clear();
+  real* data = out->getData();
+  CHECK_EQ(bs_ * layerSize_, out->getElementCnt());
+  for (int i = 0; i < bs_; ++i) {
+    MatrixPtr tmp =
+        Matrix::create(data + i * layerSize_, 1, layerSize_, false, false);
+    outs.push_back(MKLDNNMatrix::create(bias->getPrimitiveDesc(), tmp));
+  }
+}
+
 void MKLDNNAddtoLayer::resetFwdBuffers(std::vector& inputs,
+                                       MKLDNNMatrixPtr& bias,
                                        MKLDNNMatrixPtr& out) {
   inputs.resize(inputLayers_.size());
   for (size_t i = 0; i < inputs.size(); i++) {
@@ -110,12 +140,20 @@ void MKLDNNAddtoLayer::resetFwdBuffers(std::vector& inputs,
   }
 
   resetOutValue(out, inputs[0]->getPrimitiveDesc());
+
+  if (biases_ && biases_->getW()) {
+    prepareBias(bias, biases_->getW(), out, vals_);
+  } else {
+    bias = nullptr;
+  }
 }
 
 void MKLDNNAddtoLayer::resetFwdPD(std::shared_ptr& pd,
+                                  std::shared_ptr& biasPD,
                                   std::vector& inputs,
+                                  MKLDNNMatrixPtr bias,
                                   MKLDNNMatrixPtr out) {
-  std::vector scales(inputs.size(), 1.0);
+  std::vector scales(inputs.size(), 1.0);
   std::vector srcPDs;
   for (size_t i = 0; i < inputs.size(); i++) {
     srcPDs.push_back(inputs[i]->getPrimitiveDesc());
@@ -123,12 +161,23 @@ void MKLDNNAddtoLayer::resetFwdPD(std::shared_ptr& pd,
   CHECK(out);
   pd.reset(new sum::primitive_desc(out->getMemoryDesc(), scales, srcPDs));
   CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
+
+  biasPD = nullptr;
+  if (bias) {
+    std::vector scales(2, 1.0);
+    std::vector srcPDs(2, bias->getPrimitiveDesc());
+    biasPD.reset(
+        new sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs));
+    CHECK_PRIMITIVE_DESC_EQ(bias, biasPD->dst_primitive_desc());
+  }
 }
 
 void MKLDNNAddtoLayer::resetFwdPipeline(
     std::vector& pipeline,
     std::shared_ptr& pd,
+    std::shared_ptr& biasPD,
     std::vector& inputs,
+    MKLDNNMatrixPtr& bias,
     MKLDNNMatrixPtr& out) {
   std::vector srcs;
   for (size_t i = 0; i < inputs.size(); i++) {
@@ -136,9 +185,23 @@ void MKLDNNAddtoLayer::resetFwdPipeline(
   }
   fwd_.reset(new sum(*pd, srcs, *out));
   pipeline.push_back(*fwd_);
+
+  fwdBias_.clear();
+  if (biasPD == nullptr || bias == nullptr) {
+    return;
+  }
+  fwdBias_.resize(vals_.size());
+  for (size_t i = 0; i < vals_.size(); ++i) {
+    std::vector srcs;
+    srcs.push_back(*(vals_[i]));
+    srcs.push_back(*bias);
+    fwdBias_[i].reset(new sum(*biasPD, srcs, *vals_[i]));
+    pipeline.push_back(*fwdBias_[i]);
+  }
 }
 
 void MKLDNNAddtoLayer::resetBwdBuffers(std::vector& inputs,
+                                       MKLDNNMatrixPtr& bias,
                                        MKLDNNMatrixPtr& out) {
   CHECK(outVal_);
   resetOutGrad(out, outVal_->getPrimitiveDesc());
@@ -149,6 +212,12 @@ void MKLDNNAddtoLayer::resetBwdBuffers(std::vector& inputs,
     resetInGrad(inputs[i], inVal_->getPrimitiveDesc(), i);
     CHECK_PRIMITIVE_DESC_EQ(inputs[i], out->getPrimitiveDesc());
   }
+
+  if (biases_ && biases_->getWGrad()) {
+    prepareBias(bias, biases_->getWGrad(), out, grads_);
+  } else {
+    bias = nullptr;
+  }
 }
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNAddtoLayer.h b/paddle/gserver/layers/MKLDNNAddtoLayer.h
index 15f74ec5bdf3d1e4ae5e09051be6be418590a67a..24504b7b4f50726e2b2757ca3029461cdc27b411 100644
--- a/paddle/gserver/layers/MKLDNNAddtoLayer.h
+++ b/paddle/gserver/layers/MKLDNNAddtoLayer.h
@@ -32,9 +32,15 @@ protected:
   // layer size == ic * ih * iw == oc * oh *ow, and can not be changed
   size_t layerSize_;
 
-  // TODO(TJ): this part has not been optimized by MKL-DNN
   std::unique_ptr biases_;
 
+  // buffers for adding bias
+  std::vector vals_;
+  std::vector grads_;
+  // primitives for adding bias
+  std::vector> fwdBias_;
+  std::shared_ptr bwdBias_;
+
 public:
   explicit MKLDNNAddtoLayer(const LayerConfig& config) : MKLDNNLayer(config) {}
 
@@ -91,20 +97,34 @@ protected:
    *                    reset pipeline.
    */
   void resetFwdBuffers(std::vector& inputs,
+                       MKLDNNMatrixPtr& bias,
                        MKLDNNMatrixPtr& out);
   void resetFwdPD(std::shared_ptr& pd,
+                  std::shared_ptr& biasPD,
                   std::vector& inputs,
+                  MKLDNNMatrixPtr bias,
                   MKLDNNMatrixPtr out);
   void resetFwdPipeline(std::vector& pipeline,
                         std::shared_ptr& pd,
+                        std::shared_ptr& biasPD,
                         std::vector& inputs,
+                        MKLDNNMatrixPtr& bias,
                         MKLDNNMatrixPtr& out);
 
   /**
    * Backward functions: reset buffers(inputs, output, bias)
    */
   void resetBwdBuffers(std::vector& inputs,
+                       MKLDNNMatrixPtr& bias,
                        MKLDNNMatrixPtr& out);
+
+  /**
+   * prepare for bias
+   */
+  void prepareBias(MKLDNNMatrixPtr& bias,
+                   const MatrixPtr& biasMat,
+                   const MKLDNNMatrixPtr& out,
+                   std::vector& outs);
 };
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
index 9b0ae20f089e34a719883bc65e88e33ab9334e39..ed3887cbf653878623764a310c9f364f4d8be27f 100644
--- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
@@ -119,7 +119,7 @@ void MKLDNNBatchNormLayer::reshape(
     int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
   reshapeInput(bs, ih, iw);
   oh = ih;
-  ow = ow;
+  ow = iw;
   // ic_ and oc can not be changed
   CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
       << "Input channel can not be changed";
diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp
index 82ef344c7b2aa0093a5f0a28780592dea5d51efe..e75ac5ba4647a8267b7bc189893bd7adb5c3053f 100644
--- a/paddle/gserver/layers/MKLDNNLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNLayer.cpp
@@ -287,7 +287,7 @@ void MKLDNNLayer::resetMergeGrad(MKLDNNMatrixPtr& out) {
     return;
   }
   CHECK(out) << "should have reset internal ouput grad";
-  std::vector scales(outputMap_.size(), 1.0);
+  std::vector scales(outputMap_.size(), 1.0);
   std::vector srcPDs;
   std::vector srcs;
   for (auto it = outputMap_.begin(); it != outputMap_.end(); ++it) {
diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.cpp b/paddle/gserver/layers/ScaleSubRegionLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aa6778aef4e893208fd064ca22e217c6c4d960f9
--- /dev/null
+++ b/paddle/gserver/layers/ScaleSubRegionLayer.cpp
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionLayer.h"
+#include "paddle/utils/Stat.h"
+namespace paddle {
+
+REGISTER_LAYER(scale_sub_region, ScaleSubRegionLayer);
+
+bool ScaleSubRegionLayer::init(const LayerMap& layerMap,
+                               const ParameterMap& parameterMap) {
+  Layer::init(layerMap, parameterMap);
+  CHECK_EQ(static_cast(inputLayers_.size()), 2);
+  auto& conf = config_.inputs(0).scale_sub_region_conf();
+  value_ = conf.value();
+
+  createFunction(forward_, "ScaleSubRegion", FuncConfig().set("value", value_));
+  createFunction(
+      backward_, "ScaleSubRegionGrad", FuncConfig().set("value", value_));
+
+  return true;
+}
+
+void ScaleSubRegionLayer::forward(PassType passType) {
+  Layer::forward(passType);
+  auto in0 = getInput(0);
+  imgH_ = in0.getFrameHeight();
+  imgW_ = in0.getFrameWidth();
+  if (imgH_ == 0 || imgW_ == 0) {
+    auto& conf = config_.inputs(0).scale_sub_region_conf();
+    imgH_ = conf.image_conf().img_size_y();
+    imgW_ = conf.image_conf().img_size();
+  }
+  MatrixPtr imgV = in0.value;
+  size_t batchSize = imgV->getHeight();
+  size_t spatialSize = imgH_ * imgW_;
+  channelsNum_ = imgV->getWidth() / spatialSize;
+  shape_ = TensorShape({batchSize, channelsNum_, imgH_, imgW_});
+
+  resetOutput(batchSize, imgV->getWidth());
+  auto& out = getOutput();
+  out.setFrameHeight(imgH_);
+  out.setFrameWidth(imgW_);
+
+  MatrixPtr indicesV = getInputValue(1);
+  indicesShape_ = TensorShape({batchSize, 6});
+
+  REGISTER_TIMER_INFO("ScaleSubRegionForward", getName().c_str());
+  BufferArgs inArgs;
+  BufferArgs outArgs;
+  inArgs.addArg(*imgV, shape_);
+  inArgs.addArg(*indicesV, indicesShape_);
+  outArgs.addArg(*out.value, shape_, ASSIGN_TO);
+  forward_[0]->calc(inArgs, outArgs);
+}
+
+void ScaleSubRegionLayer::backward(const UpdateCallback& callback) {
+  REGISTER_TIMER_INFO("ScaleSubRegionBackward", getName().c_str());
+  BufferArgs inArgs;
+  BufferArgs outArgs;
+  inArgs.addArg(*getOutputGrad(), shape_);
+  inArgs.addArg(*getInputValue(1), indicesShape_);
+  outArgs.addArg(*getInputGrad(0), shape_, ADD_TO);
+  backward_[0]->calc(inArgs, outArgs);
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.h b/paddle/gserver/layers/ScaleSubRegionLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..a27c56de93bb6fdde0f95cd4c5abe5dfabe4e858
--- /dev/null
+++ b/paddle/gserver/layers/ScaleSubRegionLayer.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * \brief  For each instance, this layer can be used to multiply a value to a
+ *         specified sub continuous region. By providing start index and end
+ *         index for C/H/W, you can specify the location and shape of the
+ *         region.
+ *
+ *         input_0: Input value.
+ *         input_1: Indices value to specify the location an shape of the
+ *                  region.
+ */
+class ScaleSubRegionLayer : public Layer {
+public:
+  explicit ScaleSubRegionLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~ScaleSubRegionLayer() {}
+
+  bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
+
+  void forward(PassType passType);
+
+  void backward(const UpdateCallback& callback = nullptr);
+
+protected:
+  TensorShape shape_;
+  TensorShape indicesShape_;
+  size_t imgH_;
+  size_t imgW_;
+  size_t channelsNum_;
+  real value_;
+};
+
+}  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 1a46fb49153a0aa4228f58db481b950bc2d6de83..df73e6781533def5641635e9dfa9c9e4e8a0b57f 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -53,7 +53,7 @@ TEST(Operator, dot_mul) {
 TEST(Projection, context) {
   for (auto contextStart : {-5, -3, -1, 0, 3}) {
     for (auto contextLength : {1, 2, 5, 7}) {
-      for (auto batchSize : {1, 2, 5, 20, 50}) {
+      for (auto batchSize : {1, 2, 5, 20}) {
         for (auto trainablePadding : {false, true}) {
           LOG(INFO) << " contextStart=" << contextStart
                     << " contextLength=" << contextLength
@@ -585,14 +585,14 @@ TEST(Layer, maxoutLayer) {
 }
 void testFcLayer(string format, size_t nnz) {
   TestConfig config;
-  config.biasSize = 4096;
+  config.biasSize = 1024;
   config.layerConfig.set_type("fc");
-  config.layerConfig.set_size(4096);
+  config.layerConfig.set_size(1024);
   config.layerConfig.set_active_type("sigmoid");
   config.layerConfig.set_drop_rate(0.1);
 
   config.inputDefs.push_back(
-      {INPUT_DATA, "layer_0", 8192, nnz, ParaSparse(format)});
+      {INPUT_DATA, "layer_0", 2048, nnz, ParaSparse(format)});
   config.layerConfig.add_inputs();
 
   LOG(INFO) << config.inputDefs[0].sparse.sparse << " "
@@ -609,9 +609,9 @@ void testFcLayer(string format, size_t nnz) {
 }
 
 TEST(Layer, fcLayer) {
-  testFcLayer("", 4096 * 4096 * 2);
-  testFcLayer("csc", 4096 * 40);
-  testFcLayer("csr", 4096 * 40);
+  testFcLayer("", 1024 * 1024 * 2);
+  testFcLayer("csc", 1024 * 10);
+  testFcLayer("csr", 1024 * 10);
 }
 
 TEST(Layer, SelectiveFullyConnectedLayer) {
@@ -1995,7 +1995,7 @@ TEST(Layer, multibox_loss) {
 TEST(Layer, TransLayer) {
   TestConfig config;
   const int height = 128;
-  const int width = 1028;
+  const int width = 256;
   config.layerConfig.set_type("trans");
   config.layerConfig.set_size(width);
 
@@ -2358,6 +2358,38 @@ TEST(Layer, ScaleShiftLayer) {
   }
 }
 
+TEST(Layer, ScaleSubRegionLayer) {
+  const size_t batchSize = 64;
+  const size_t size = 4096;
+  TestConfig config;
+  config.layerConfig.set_type("scale_sub_region");
+  config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
+  MatrixPtr indicesV = Matrix::create(batchSize, 6, false, false);
+  auto* data = indicesV->getData();
+  for (size_t i = 0; i < batchSize; ++i) {
+    data[i * 2] = 2;
+    data[i * 2 + 1] = 4;
+    data[i * 2 + 2] = 16;
+    data[i * 2 + 3] = 32;
+    data[i * 2 + 4] = 16;
+    data[i * 2 + 5] = 32;
+  }
+  config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "indices", indicesV, {}});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  ScaleSubRegionConfig* scaleSubRegionConf =
+      input->mutable_scale_sub_region_conf();
+  ImageConfig* imgConf = scaleSubRegionConf->mutable_image_conf();
+  imgConf->set_img_size(32);
+  imgConf->set_img_size_y(32);
+  imgConf->set_channels(4);
+  scaleSubRegionConf->set_value(2.0);
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "scale_sub_region", batchSize, false, useGpu, false);
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp
index 2e8d9f3333b36005c9b3b28449c76a4a44c74cc6..a0e039c2a33b586e21775ad06c1278a10804d654 100644
--- a/paddle/gserver/tests/test_MKLDNN.cpp
+++ b/paddle/gserver/tests/test_MKLDNN.cpp
@@ -269,6 +269,7 @@ void testBatchNormLayer(const testBatchNormDesc& pm) {
 TEST(MKLDNNLayer, BatchNormLayer) {
   testBatchNormLayer({4, 10, 6, 6});
   testBatchNormLayer({16, 32, 16, 16});
+  testBatchNormLayer({4, 16, 8, 10});
 }
 
 struct testImageDesc {
@@ -300,13 +301,8 @@ void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) {
   TestConfig dnnConfig;
   getAddtoConfig(dnnConfig, pm, nInputs);
   dnnConfig.layerConfig.set_type("mkldnn_addto");
-  // TODO(TJ): test with bias
-  for (auto withBias : {false}) {
-    if (withBias) {
-      dnnConfig.biasSize = pm.ic * pm.ih * pm.iw;
-    } else {
-      dnnConfig.biasSize = 0;
-    }
+  for (auto withBias : {false, true}) {
+    dnnConfig.biasSize = withBias ? pm.ic * pm.ih * pm.iw : 0;
     RUN_MKLDNN_TEST_LAYER(dnnConfig, "addto", pm)
   }
 }
diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp
index c2f17beeb87942ea681f5d388659c0d280157b26..ba86eacbb5d53ee43a60d2cd1dd922333a5d48f0 100644
--- a/paddle/math/MathFunctions.cpp
+++ b/paddle/math/MathFunctions.cpp
@@ -206,7 +206,7 @@ double dotProduct(const int n, const double* x, const double* y) {
 }
 #endif
 
-#if defined(PADDLE_USE_MKL) || defined(PADDLE_USE_MKLML)
+#if defined(PADDLE_USE_MKLML)
 
 template <>
 void vExp(const int n, const float* a, float* r) {
@@ -295,38 +295,6 @@ template void vAdd(const int n, const double* a, const double* b, double* r);
 
 #endif
 
-#ifdef PADDLE_USE_MKL
-template <>
-void vInvSqrt(const int n, const float* a, float* r) {
-  vsInvSqrt(n, a, r);
-}
-
-template <>
-void vInvSqrt(const int n, const double* a, double* r) {
-  vdInvSqrt(n, a, r);
-}
-
-template <>
-void vLog1p(const int n, const float* a, float* r) {
-  vsLog1p(n, a, r);
-}
-
-template <>
-void vLog1p(const int n, const double* a, double* r) {
-  vdLog1p(n, a, r);
-}
-
-template <>
-void vTanh(const int n, const float* a, float* r) {
-  vsTanh(n, a, r);
-}
-
-template <>
-void vTanh(const int n, const double* a, double* r) {
-  vdTanh(n, a, r);
-}
-#else
-
 DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
 template 
 void vInvSqrt(const int n, const T* a, T* r) {
@@ -357,6 +325,4 @@ template void vLog1p(const int n, const double* a, double* r);
 template void vTanh(const int n, const float* a, float* r);
 template void vTanh(const int n, const double* a, double* r);
 
-#endif
-
 }  // namespace paddle
diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h
index 8193aa4adffc0409d8ea68417c68fa153a2942d8..f6e77029bdd75a602f88b688ca810f47ba4ee615 100644
--- a/paddle/math/MathFunctions.h
+++ b/paddle/math/MathFunctions.h
@@ -21,11 +21,6 @@ limitations under the License. */
 #include 
 #endif
 
-#ifdef PADDLE_USE_MKL
-#include 
-#include 
-#endif
-
 #if defined(PADDLE_USE_ATLAS) || defined(PADDLE_USE_VECLIB)
 extern "C" {
 #include 
diff --git a/paddle/math/tests/TensorCheck.h b/paddle/math/tests/TensorCheck.h
index 5bc4a03067a75527fa30e5bb5526f93dc7b9fdcc..b998e5772e70d0a0ec79dc4064dcbaa2c302efd2 100644
--- a/paddle/math/tests/TensorCheck.h
+++ b/paddle/math/tests/TensorCheck.h
@@ -169,7 +169,7 @@ void TensorCheck(AssertEq compare,
       count++;
     }
   }
-  EXPECT_EQ(count, 0) << "There are " << count << " different element.";
+  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
 }
 
 template 
diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index d0c4c0d25d6f4e3ab7acd72d62a8a17fa102637b..1776f33105367447759aa91c25263dfc53bd2f99 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -65,7 +65,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
 
     size_t num_samples = inference->dims()[0];
     size_t infer_width = inference->dims()[1];
-    cudaMemset((void**)&accuracy_data, 0, sizeof(float));
+    PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
 
     if (num_samples == 0) {
       return;
diff --git a/paddle/operators/array_operator.h b/paddle/operators/array_operator.h
new file mode 100644
index 0000000000000000000000000000000000000000..666043e824f885e9c0e79e319d0a38ba108c209a
--- /dev/null
+++ b/paddle/operators/array_operator.h
@@ -0,0 +1,50 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/lod_tensor_array.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+class ArrayOp : public framework::OperatorBase {
+ public:
+  ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
+          const framework::VariableNameMap &outputs,
+          const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+ protected:
+  size_t GetOffset(const framework::Scope &scope,
+                   const platform::DeviceContext &dev_ctx) const {
+    auto *i = scope.FindVar(Input("I"));
+    PADDLE_ENFORCE(i != nullptr, "I must be set");
+    auto &i_tensor = i->Get();
+    PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
+    size_t offset;
+    if (platform::is_gpu_place(i_tensor.place())) {
+      // FIXME: Avoid copy from GPU to CPU
+      framework::Tensor t;
+      t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
+      dev_ctx.Wait();
+      offset = static_cast(*t.data());
+    } else {
+      offset = static_cast(*i_tensor.data());
+    }
+    return offset;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc
index 6cd9c06b8ae3d3b17be83268c2f5d4002705b111..c0903bb4e5ca7f160e19eefab99af7e3e4a8ed76 100644
--- a/paddle/operators/array_to_lod_tensor_op.cc
+++ b/paddle/operators/array_to_lod_tensor_op.cc
@@ -140,6 +140,23 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
                    "ArrayToLoDTensorOp must has input X.");
     PADDLE_ENFORCE(context->HasInput("RankTable"),
                    "ArrayToLoDTensorOp must has input RankTable.");
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+  }
+};
+
+class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr Apply() const override {
+    auto *grad_op = new framework::OpDescBind();
+    grad_op->SetType("lod_tensor_to_array");
+    grad_op->SetInput("X", OutputGrad("Out"));
+    grad_op->SetInput("RankTable", Input("RankTable"));
+    grad_op->SetOutput("Out", InputGrad("X"));
+    grad_op->SetAttrMap(Attrs());
+    return std::unique_ptr(grad_op);
   }
 };
 
@@ -149,4 +166,5 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(array_to_lod_tensor, ops::ArrayToLoDTensorOp,
                   ops::ArrayToLoDTensorOpProtoMaker,
-                  ops::ArrayToLoDTensorInferShape);
+                  ops::ArrayToLoDTensorInferShape,
+                  ops::ArrayToLoDTensorGradMaker);
diff --git a/paddle/operators/clip_by_norm_op.cc b/paddle/operators/clip_by_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d9fc532e39500fa397be80396b075e866bad9362
--- /dev/null
+++ b/paddle/operators/clip_by_norm_op.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/clip_by_norm_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ClipByNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ClipByNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ClipByNormOp should not be null.");
+    auto max_norm = ctx->Attrs().Get("max_norm");
+    PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ClipByNormOpMaker(framework::OpProto* proto,
+                    framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(Tensor) The input of clip_by_norm op."
+             "The number of dimensions must be between [1, 9].");
+    AddOutput("Out",
+              "(Tensor) The output of clip_by_norm op with shape as input(X)");
+    AddAttr("max_norm", "(float) The maximum norm value.");
+    AddComment(R"DOC(
+ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'. 
+If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be 
+the same as 'X'. If the L2 norm of 'X' is greater than 'max_norm', 'X' will 
+be linearly scaled to make the L2 norm of 'Out' equal to 'max_norm', as 
+shown in the following formula:
+
+'Out' = 'max_norm' * 'X' / norm('X'),
+
+where norm('X') represents the L2 norm of 'X'.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
+                             ops::ClipByNormOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    clip_by_norm, ops::ClipByNormKernel);
diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/clip_by_norm_op.cu
similarity index 63%
rename from paddle/operators/fill_constant_op.cu
rename to paddle/operators/clip_by_norm_op.cu
index bca402a8b988b570a083e9ce253342304f4b8946..2593a24ebbf56ecd286a726e527d2414247576e8 100644
--- a/paddle/operators/fill_constant_op.cu
+++ b/paddle/operators/clip_by_norm_op.cu
@@ -12,13 +12,8 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#define EIGEN_USE_GPU
-#include "paddle/framework/op_registry.h"
-#include "paddle/operators/fill_constant_op.h"
+#include "paddle/operators/clip_by_norm_op.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    fill_constant, ops::FillConstantOpKernel,
-    ops::FillConstantOpKernel,
-    ops::FillConstantOpKernel,
-    ops::FillConstantOpKernel);
+    clip_by_norm, ops::ClipByNormKernel);
diff --git a/paddle/operators/clip_by_norm_op.h b/paddle/operators/clip_by_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..b26476cae9b5b2fa290bc9186b9a64c48ba703d6
--- /dev/null
+++ b/paddle/operators/clip_by_norm_op.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template 
+using EigenVector = framework::EigenVector;
+
+template 
+class ClipByNormKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto max_norm = context.Attr("max_norm");
+    auto* input = context.Input("X");
+    auto* output = context.Output("Out");
+    output->mutable_data(context.GetPlace());
+
+    auto x = EigenVector::Flatten(*input);
+    auto out = EigenVector::Flatten(*output);
+    auto x_norm = x.square().sum().sqrt();
+    auto place = context.GetEigenDevice();
+
+    auto temp = (x_norm <= max_norm).template cast().eval();
+    auto scaling = temp + (static_cast(1) - temp) * max_norm / x_norm;
+    Eigen::array one_dim{{1}};
+    Eigen::DSizes m_dsize(input->numel());
+    out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc
index 8b425d14df3bc484437dc72f29abf13b887006bd..716b5ee92d0d8737d2069460f53989f691ff7c77 100644
--- a/paddle/operators/compare_op.cc
+++ b/paddle/operators/compare_op.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/operators/compare_op.h"
 #include "paddle/framework/op_registry.h"
+
 namespace paddle {
 namespace operators {
 template 
@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
   }
 };
 
+class CompareOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
+    // CompareOp kernel's device type is decided by input tensor place
+    kt.place_ = ctx.Input("X")->place();
+    return kt;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
-#define REGISTER_LOGICAL_OP(op_type, _equation)                               \
-  struct _##op_type##Comment {                                                \
-    static char type[];                                                       \
-    static char equation[];                                                   \
-  };                                                                          \
-  char _##op_type##Comment::type[]{#op_type};                                 \
-  char _##op_type##Comment::equation[]{_equation};                            \
-  REGISTER_OP_WITH_KERNEL(                                                    \
-      op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
-      ::paddle::operators::CompareOpInferShape<_##op_type##Comment>,          \
+#define REGISTER_LOGICAL_OP(op_type, _equation)                      \
+  struct _##op_type##Comment {                                       \
+    static char type[];                                              \
+    static char equation[];                                          \
+  };                                                                 \
+  char _##op_type##Comment::type[]{#op_type};                        \
+  char _##op_type##Comment::equation[]{_equation};                   \
+  REGISTER_OPERATOR(                                                 \
+      op_type, ::paddle::operators::CompareOp,                       \
+      ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
+      ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
       ::paddle::framework::EmptyGradOpMaker);
 
 REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
diff --git a/paddle/operators/fill_constant_batch_size_like_op.cc b/paddle/operators/fill_constant_batch_size_like_op.cc
index f86ee3c3d88670c0e43f20fdf35b8424438e0486..85871ebbfcd8ee38ef5e8078d1d6cb6bdda46a7b 100644
--- a/paddle/operators/fill_constant_batch_size_like_op.cc
+++ b/paddle/operators/fill_constant_batch_size_like_op.cc
@@ -75,10 +75,10 @@ class FillConstantBatchSizeLikeOpMaker
               "with the specified value");
     AddAttr>("shape", "(vector) The shape of the output");
     AddAttr("input_dim_idx",
-                 "(int, default 0) the index of input's batch size dimension")
+                 "(int, default 0) The index of input's batch size dimension")
         .SetDefault(0);
     AddAttr("output_dim_idx",
-                 "(int, default 0) the index of output's batch size dimension")
+                 "(int, default 0) The index of output's batch size dimension")
         .SetDefault(0);
     AddAttr("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);
diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc
index 5a1cba51f83bb8577bc94ae23d1a44bb801ae4c7..818f113b90a4c239a857791fb9957e51d3287b97 100644
--- a/paddle/operators/fill_constant_op.cc
+++ b/paddle/operators/fill_constant_op.cc
@@ -12,33 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/fill_constant_op.h"
+#include "paddle/framework/data_type.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
 
-class FillConstantOp : public framework::OperatorWithKernel {
+class FillConstantInferShape : public framework::InferShapeBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void operator()(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of FillConstantOp should not be null.");
     auto &shape = ctx->Attrs().Get>("shape");
-    std::vector shape_int64(shape.size(), 0);
-    std::transform(shape.begin(), shape.end(), shape_int64.begin(),
-                   [](int a) { return static_cast(a); });
-    auto dims = framework::make_ddim(shape_int64);
-    ctx->SetOutputDim("Out", dims);
+    ctx->SetOutputDim("Out", framework::make_ddim(shape));
   }
+};
 
- protected:
-  framework::OpKernelType GetKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    int data_type = ctx.Attr("data_type");
-    VLOG(10) << " FillConstant data_type = " << data_type;
-    return framework::OpKernelType(static_cast(data_type),
-                                   ctx.device_context());
+class FillConstantOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    auto data_type = static_cast(Attr("data_type"));
+    auto value = Attr("value");
+    auto force_cpu = Attr("force_cpu");
+    auto &out =
+        *scope.FindVar(Output("Out"))->GetMutable();
+    out.Resize(framework::make_ddim(Attr>("shape")));
+    if (force_cpu) {
+      auto cpu = platform::CPUPlace();
+      out.mutable_data(cpu, framework::ToTypeIndex(data_type));
+    } else {
+      out.mutable_data(dev_ctx.GetPlace(), framework::ToTypeIndex(data_type));
+    }
+    math::set_constant(dev_ctx, &out, value);
   }
 };
 
@@ -54,6 +62,11 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr>("shape", "(vector) The shape of the output");
     AddAttr("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);
+    AddAttr("force_cpu",
+                  "(bool, default false) Force fill output variable to cpu "
+                  "memory. Otherwise, fill output variable to the running "
+                  "device")
+        .SetDefault(false);
     AddOutput("Out",
               "(Tensor) Tensor of specified shape will be filled "
               "with the specified value");
@@ -69,10 +82,6 @@ Fill up a variable with specified constant value.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
-                             ops::FillConstantOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    fill_constant, ops::FillConstantOpKernel,
-    ops::FillConstantOpKernel,
-    ops::FillConstantOpKernel,
-    ops::FillConstantOpKernel);
+REGISTER_OPERATOR(fill_constant, ops::FillConstantOp,
+                  ops::FillConstantInferShape, ops::FillConstantOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
diff --git a/paddle/operators/fill_constant_op.h b/paddle/operators/fill_constant_op.h
deleted file mode 100644
index 3668f42f1c29541e29463ff3969064e80703fa04..0000000000000000000000000000000000000000
--- a/paddle/operators/fill_constant_op.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-template 
-class FillConstantOpKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* out = ctx.Output("Out");
-    out->mutable_data(ctx.GetPlace());
-    auto value = ctx.Attr("value");
-
-    auto out_eigen = framework::EigenVector::Flatten(*out);
-    auto place = ctx.GetEigenDevice();
-    out_eigen.device(place) = out_eigen.constant(static_cast(value));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc
index 5f02f5e8a12831a33683cdc53cf0feb7cb908da5..58af35564d83b9699af4f7783fb6367ff9590682 100644
--- a/paddle/operators/lod_tensor_to_array_op.cc
+++ b/paddle/operators/lod_tensor_to_array_op.cc
@@ -133,6 +133,22 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
   }
 };
 
+class LoDTensorToArrayGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr Apply() const override {
+    auto *grad_op = new framework::OpDescBind();
+    grad_op->SetType("array_to_lod_tensor");
+    grad_op->SetInput("X", OutputGrad("Out"));
+    grad_op->SetInput("RankTable", Input("RankTable"));
+    grad_op->SetOutput("Out", InputGrad("X"));
+    grad_op->SetAttrMap(Attrs());
+    return std::unique_ptr(grad_op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -140,4 +156,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(lod_tensor_to_array, ops::LoDTensorToArrayOp,
                   ops::LoDTensorToArrayOpProtoMaker,
                   ops::LoDTensorToArrayInferShape,
-                  ops::LoDTensorToArrayInferVarType);
+                  ops::LoDTensorToArrayInferVarType,
+                  ops::LoDTensorToArrayGradMaker);
diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc
index 6b859dbbe7f760a93133e0cb12b6bd3fc5fd88e0..4cbb60f3fdab968e8c36d4fbad55fd3efc7b1d0d 100644
--- a/paddle/operators/lstm_op.cc
+++ b/paddle/operators/lstm_op.cc
@@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"),
                    "Input(Input) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(Weight) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Bias"),
+                   "Input(Bias) of LSTM should not be null.");
+
     PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                    "Output(Hidden) of LSTM should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Cell"),
@@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel {
                       "The second dimension of Input(Weight) "
                       "should be 4 * %d.",
                       frame_size);
+
     auto b_dims = ctx->GetInputDim("Bias");
     PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
     PADDLE_ENFORCE_EQ(b_dims[0], 1,
                       "The first dimension of Input(Bias) should be 1.");
-    if (ctx->Attrs().Get("usePeepholes")) {
+
+    if (ctx->Attrs().Get("use_peepholes")) {
       PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
                         "The second dimension of Input(Bias) should be "
                         "7 * %d if enable peepholes connection",
@@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel {
                         "4 * %d if disable peepholes connection",
                         frame_size);
     }
+
     framework::DDim out_dims({in_dims[0], frame_size});
     ctx->SetOutputDim("Hidden", out_dims);
     ctx->SetOutputDim("Cell", out_dims);
@@ -118,14 +126,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Bias",
              "(Tensor) the learnable weights, which contains two parts: "
              "input-hidden bias weight and peephole connections weight if "
-             "setting `usePeepholes` True. "
-             "1. `usePeepholes = False` "
+             "setting `use_peepholes` True. "
+             "1. `use_peepholes = False` "
              " - The shape is (1 x 4D). "
              " - Bias = {b_c, b_i, b_f, b_o}."
-             "2. `usePeepholes = True` "
+             "2. `use_peepholes = True` "
              " - The shape is (1 x 7D). "
-             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
-        .AsDispensable();
+             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
     AddOutput("Hidden",
               "(LoDTensor) the hidden state of LSTM operator. "
               "The shape is (T x D), and lod is the same with the `Input`.");
@@ -145,29 +152,32 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
               "(LoDTensor) This LoDTensor is obtained in the forward and used "
               "in the backward.")
         .AsIntermediate();
-    AddAttr("usePeepholes",
-                  "(bool, default True) "
+    AddAttr("use_peepholes",
+                  "(bool, defalut: True) "
                   "whether to enable diagonal/peephole connections.")
         .SetDefault(true);
-    AddAttr("isReverse",
-                  "(bool, default False) "
+    AddAttr("is_reverse",
+                  "(bool, defalut: False) "
                   "whether to compute reversed LSTM.")
         .SetDefault(false);
     AddAttr(
-        "gateActivation",
-        "(string, default sigmoid)"
+        "gate_activation",
+        "(string, default: sigmoid)"
         "The activation for input gate, forget gate and output "
         "gate, `sigmoid` by default.")
-        .SetDefault("sigmoid");
-    AddAttr("cellActivation",
-                         "(string, default tanh)"
+        .SetDefault("sigmoid")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
+    AddAttr("cell_activation",
+                         "(string, default: tanh)"
                          "The activation for cell output, `tanh` by defalut.")
-        .SetDefault("tanh");
-    AddAttr("candidateActivation",
-                         "(string, default tanh)"
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
+    AddAttr("candidate_activation",
+                         "(string, default: tanh)"
                          "The activation for candidate hidden state, "
                          "`tanh` by default.")
-        .SetDefault("tanh");
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddComment(R"DOC(
 Long-Short Term Memory (LSTM) Operator.
 
@@ -203,7 +213,7 @@ are the cell input and cell output activation functions and `tanh` is usually
 used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
 which is computed based on the current input and the previous hidden state.
 
-Set usePeepholes False to disable peephole connection 
+Set `use_peepholes` False to disable peephole connection 
 (http://www.bioinf.jku.at/publications/older/2604.pdf). The formula
 is omitted here.
 
@@ -226,23 +236,27 @@ class LSTMGradOp : public framework::OperatorWithKernel {
                    "Input(Hidden) of LSTM should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Cell"),
                    "Input(Cell) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(Weight) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Bias"),
+                   "Input(Bias) of LSTM should not be null.");
 
     PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
                    "Input(BatchGate) of LSTM should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
                    "Input(BatchGate) of LSTM should not be null.");
 
-    auto in_g_name = framework::GradVarName("Input");
-    if (ctx->HasOutput(in_g_name))
-      ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));
-
-    auto w_g_name = framework::GradVarName("Weight");
-    if (ctx->HasOutput(w_g_name))
-      ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));
-
-    auto b_g_name = framework::GradVarName("Bias");
-    if (ctx->HasOutput(b_g_name))
-      ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
+    auto SetOutGradDim = [&ctx](const std::string& name) {
+      auto g_name = framework::GradVarName(name);
+      if (ctx->HasOutput(g_name))
+        ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
+    };
+
+    SetOutGradDim("Input");
+    SetOutGradDim("Weight");
+    SetOutGradDim("Bias");
+    SetOutGradDim("H0");
+    SetOutGradDim("C0");
   }
 
  protected:
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index af088b80b4283cf221a1dff74546d73d977fada3..fca84e2d8fa832a3780eab7e0fa2facceb4d613b 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -28,6 +28,15 @@ template 
 using EigenMatrix = framework::EigenMatrix;
 
+template 
+inline void ReorderInitState(const platform::DeviceContext& ctx,
+                             const framework::Tensor& src, const size_t* index,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor row_shuffle;
+  dst->mutable_data(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index, *dst, indexed_src);
+}
+
 template 
 class LSTMKernel : public framework::OpKernel {
  public:
@@ -36,6 +45,9 @@ class LSTMKernel : public framework::OpKernel {
     auto* weight = ctx.Input("Weight");
     auto* bias = ctx.Input("Bias");
 
+    auto* hidden_t0 = ctx.Input("H0");
+    auto* cell_t0 = ctx.Input("C0");
+
     auto* batch_gate = ctx.Output("BatchGate");
     batch_gate->mutable_data(ctx.GetPlace());
     auto* hidden_out = ctx.Output("Hidden");
@@ -43,12 +55,7 @@ class LSTMKernel : public framework::OpKernel {
     auto* cell_out = ctx.Output("Cell");
     cell_out->mutable_data(ctx.GetPlace());
 
-    // Now the function ShareLoD in InferShape is not implemented.
-    // So copy LoD here.
-    ctx.ShareLoD("Input", "Hidden");
-    ctx.ShareLoD("Input", "Cell");
-
-    bool is_reverse = ctx.Attr("isReverse");
+    bool is_reverse = ctx.Attr("is_reverse");
     math::LoDTensor2BatchFunctor to_batch;
     auto& device_ctx = ctx.device_context();
     to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
@@ -71,7 +78,7 @@ class LSTMKernel : public framework::OpKernel {
     }
 
     math::LstmMetaValue lstm_value;
-    if (bias) {
+    if (bias && ctx.Attr("use_peepholes")) {
       T* bias_data = const_cast(bias->data());
       // the code style in LstmMetaValue will be updated later.
 
@@ -84,6 +91,16 @@ class LSTMKernel : public framework::OpKernel {
       lstm_value.checkOg = nullptr;
     }
     lstm_value.prevStateValue = nullptr;
+    Tensor ordered_c0;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (cell_t0) {
+      // Since the batch computing for LSTM reorders the input sequence
+      // according to their length. The initialized cell state also needs
+      // to reorder.
+      ReorderInitState(device_ctx, *cell_t0, order, &ordered_c0,
+                                 true);
+      lstm_value.prevStateValue = ordered_c0.data();
+    }
 
     // Use the local variable as here.
     LoDTensor batch_hidden, batch_cell;
@@ -94,9 +111,9 @@ class LSTMKernel : public framework::OpKernel {
 
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto gate_act = ctx.Attr("gateActivation");
-    auto cell_act = ctx.Attr("cellActivation");
-    auto cand_act = ctx.Attr("candidateActivation");
+    auto gate_act = ctx.Attr("gate_activation");
+    auto cell_act = ctx.Attr("cell_activation");
+    auto cand_act = ctx.Attr("candidate_activation");
 
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast(batch_starts[n]);
@@ -109,15 +126,28 @@ class LSTMKernel : public framework::OpKernel {
 
       int cur_batch_size = bend - bstart;
 
-      if (n != 0) {
+      if (n > 0) {
         int pre_h_start = static_cast(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
         math::matmul(device_ctx, pre_hidden_t, false, *weight, false,
                                static_cast(1.0), &gate_t,
                                static_cast(1.0));
+      } else if (hidden_t0) {
+        // If n == 0 and there is no initialized hidden state, that is to say
+        // the H0 is zeros, the calculation W_h * H0 will be skiped.
+        // If n == 0 and there is initialized hidden state, calculate W_h * H0.
+
+        // Since the batch computing for LSTM reorders the input sequence
+        // according to their length. The initialized hidden state also needs
+        // to reorder.
+        Tensor ordered_h0;
+        ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0,
+                                   true);
+        math::matmul(device_ctx, ordered_h0, false, *weight, false,
+                               static_cast(1.0), &gate_t,
+                               static_cast(1.0));
       }
-      // else if : FIXME support the initial hidden and cell
 
       lstm_value.gateValue = gate_t.data();
       lstm_value.outputValue = out_t.data();
@@ -160,6 +190,12 @@ class LSTMGradKernel : public framework::OpKernel {
     auto* weight_g = ctx.Output(framework::GradVarName("Weight"));
     auto* bias_g = ctx.Output(framework::GradVarName("Bias"));
 
+    auto* h0 = ctx.Input("H0");
+    auto* c0 = ctx.Input("C0");
+
+    auto* h0_g = ctx.Output(framework::GradVarName("H0"));
+    auto* c0_g = ctx.Output(framework::GradVarName("C0"));
+
     auto& device_ctx = ctx.device_context();
     math::SetConstant zero;
     if (weight_g) {
@@ -167,13 +203,25 @@ class LSTMGradKernel : public framework::OpKernel {
       zero(device_ctx, weight_g, static_cast(0.0));
     }
 
+    // ordered_h0/c0 is the reordered hidden/cell initialization.
+    // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
+    // initialization.
+    Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (c0) {
+      ReorderInitState(device_ctx, *c0, order, &ordered_c0, true);
+    }
+    if (c0 && c0_g) {
+      ordered_c0_g.mutable_data(c0_g->dims(), ctx.GetPlace());
+    }
+
     auto in_dims = input->dims();
     auto out_dims = hidden_g->dims();
     int frame_size = static_cast(in_dims[1] / 4);
     PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
 
     math::LstmMetaValue lstm_value;
-    if (bias) {
+    if (bias && ctx.Attr("use_peepholes")) {
       T* bias_data = const_cast(bias->data());
       lstm_value.checkIg = bias_data + 4 * frame_size;
       lstm_value.checkFg = lstm_value.checkIg + frame_size;
@@ -185,9 +233,13 @@ class LSTMGradKernel : public framework::OpKernel {
     }
 
     math::LstmMetaGrad lstm_grad;
+
     if (bias && bias_g) {
-      T* bias_g_data = const_cast(bias_g->mutable_data(ctx.GetPlace()));
+      bias_g->mutable_data(ctx.GetPlace());
       zero(device_ctx, bias_g, static_cast(0.0));
+    }
+    if (bias && bias_g && ctx.Attr("use_peepholes")) {
+      T* bias_g_data = bias_g->data();
       lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
       lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
       lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
@@ -199,36 +251,30 @@ class LSTMGradKernel : public framework::OpKernel {
 
     math::LoDTensor2BatchFunctor to_batch;
 
-    // use the local variable as here.
-    LoDTensor batch_hidden;
-    batch_hidden.mutable_data