diff --git a/Dockerfile b/Dockerfile
index 8cfb16928c95dcbfac08383d32562ff67933d873..5dd9b0be4f7e0a304108abfdfb089fea4faa4d38 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -27,7 +27,7 @@ RUN apt-get update && \
     git python-pip python-dev openssh-server bison  \
     wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
     curl sed grep graphviz libjpeg-dev zlib1g-dev  \
-    python-numpy python-matplotlib gcc g++ \
+    python-numpy python-matplotlib gcc-4.8 g++-4.8 \
     automake locales clang-format-3.8 swig doxygen cmake  \
     liblapack-dev liblapacke-dev libboost-dev \
     clang-3.8 llvm-3.8 libclang-3.8-dev \
diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index ef31c252038ce18655913c0f41343fe6dc7dbb86..d00a9bb3a30cfb16623e073414088059481c3e1a 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
         if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
             message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
         endif()
+        # TODO(qijun) gcc 4.9 and later versions raise a SEGV due to an
+        # optimization problem. Force Debug mode as a workaround for now.
+        if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
+            set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
+        endif()
     elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
         # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
         # Apple Clang is a different compiler than upstream Clang which has different version numbers.
diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst
index ec7f1446cfb74842af7d0c7152bebf58619f3861..372272a53c12c314fc80eebbce5eae9fcabc55ba 100644
--- a/doc/api/v2/config/layer.rst
+++ b/doc/api/v2/config/layer.rst
@@ -104,6 +104,11 @@ cross_channel_norm
 ------------------
 ..  autoclass:: paddle.v2.layer.cross_channel_norm
     :noindex:
+
+row_l2_norm
+-----------
+..  autoclass:: paddle.v2.layer.row_l2_norm
+    :noindex:
     
 Recurrent Layers
 ================
@@ -320,6 +325,11 @@ scaling
 ..  autoclass:: paddle.v2.layer.scaling
     :noindex:
 
+clip
+----
+..  autoclass:: paddle.v2.layer.clip
+    :noindex:
+
 slope_intercept
 ---------------
 ..  autoclass:: paddle.v2.layer.slope_intercept
diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc
index c53a5636829cab9d575f58cc2326cb3efe383e1c..7ad8a39768a064140a08c912a5a467bc24a12adf 100644
--- a/paddle/cuda/src/hl_cuda_cudnn.cc
+++ b/paddle/cuda/src/hl_cuda_cudnn.cc
@@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
   real alpha = 1.0f;
   real beta = 1.0f;
   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
+
+  int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size;
+  if (batch_size > 1024 && g_cudnn_lib_version < 6000) {
+    LOG(INFO) << " To process current batch data with size " << batch_size
+              << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000."
+              << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED,"
+              << " just recompile PaddlePaddle with cuDNN >= 6000, replacing"
+              << " current version " << g_cudnn_lib_version;
+  }
   CHECK_CUDNN(
       dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle,
                                                        mode,
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index b73426eaacdf2eaf115b4ac02d58e02d24cc753d..f8f9bae12d42ccbc52b1046900d239ae0cde6940 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -38,7 +38,7 @@ cc_library(backward SRCS backward.cc DEPS net)
 cc_test(backward_test SRCS backward_test.cc DEPS backward)
 cc_library(paddle_pybind SHARED
     SRCS pybind.cc
-    DEPS pybind python
+    DEPS pybind python backward
 	fc_op
 	sgd_op
 	add_op
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0..3e72e391266066de9e4114e68b43b066c15254db 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -400,6 +400,14 @@ class GradOpRegisterHelper {
     return 0;                                                                  \
   }
 
+/**
+ * Macro to forbid users from registering a Gradient Operator.
+ */
+#define NO_GRADIENT(__op_type)                          \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                       \
+      __reg_gradient_op__##__op_type##__op_type##_grad, \
+      "NO_GRADIENT must be in global namespace")
+
 /**
  * Macro to Register OperatorKernel.
  */
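
A minimal usage sketch of the new macro (the operator name `sigmoid` is illustrative only): placed at global namespace scope in an operator's .cc file,

    NO_GRADIENT(sigmoid);

expands to a STATIC_ASSERT_GLOBAL_NAMESPACE guard whose symbol mirrors the one a gradient registration for `sigmoid`/`sigmoid_grad` would generate, so registering NO_GRADIENT and a gradient operator for the same op type collides at compile time.
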
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index cfe9cba308556475ef64b45e7178dfc418761598..cb86e6be2be3624bf54ee28193ca5d4c7bafa0eb 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -20,16 +20,16 @@ namespace paddle {
 namespace framework {
 
 template <>
-Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
+Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
     platform::CPUPlace, Eigen::DefaultDevice>() const {
-  return device_context_.get_eigen_device<Eigen::DefaultDevice>();
+  return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
 }
 
 #ifndef PADDLE_ONLY_CPU
 template <>
-Eigen::GpuDevice*
+Eigen::GpuDevice&
 ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
-  return device_context_.get_eigen_device<Eigen::GpuDevice>();
+  return *device_context_.get_eigen_device<Eigen::GpuDevice>();
 }
 #endif
 
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 0832a663dd01fe2921366d70599bc867e73af47c..55435103489ace11868eed61c38018d8ba357e65 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext {
   template <typename PlaceType,
             typename DeviceType =
                 typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
-  DeviceType* GetEigenDevice() const;
+  DeviceType& GetEigenDevice() const;
 
   platform::Place GetPlace() const { return device_context_.GetPlace(); }
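
With GetEigenDevice now returning a reference, kernels drop one level of indirection at the call site. A before/after sketch, mirroring the kernels updated later in this patch (add_op.h, mean_op.h, mul_op.h):

    // before: the pointer return had to be dereferenced
    // Z.device(*(context.GetEigenDevice<Place>())) = X + Y;
    // after: the returned reference binds directly
    auto place = context.GetEigenDevice<Place>();
    Z.device(place) = X + Y;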
 
diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index a735cc2ad51aaf3eaa2ad05f2ab757448b31ed49..cc47469b4db53458f6a4314f4339b58a9527637e 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
-    http://www.apache.org/licenses/LICENSE-2.0
+http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
@@ -16,11 +16,14 @@ limitations under the License. */
 #include <fstream>
 #include <vector>
 
+#include "paddle/framework/backward.h"
 #include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
 #include "paddle/framework/tensor_bind.h"
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/place.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
@@ -43,6 +46,10 @@ template <typename ClassType>
 void ExposeOperator(ClassType &m) {
   m.def("infer_shape", &ClassType::type::InferShape)
       .def("run", &ClassType::type::Run)
+      .def("type",
+           [](const typename ClassType::type &op) -> std::string {
+             return op.type_;
+           })
       .def("outputs",
            [](const typename ClassType::type &op) -> std::vector<std::string> {
              return op.outputs_;
@@ -55,6 +62,14 @@ static size_t UniqueIntegerGenerator() {
   return generator.fetch_add(1);
 }
 
+bool IsCompileGPU() {
+#ifdef PADDLE_ONLY_CPU
+  return false;
+#else
+  return true;
+#endif
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");
 
@@ -68,16 +83,29 @@ PYBIND11_PLUGIN(core) {
              self.Resize(make_ddim(dim));
            })
       .def("alloc_float",
-           [](Tensor &self) {
-             self.mutable_data<float>(paddle::platform::CPUPlace());
+           [](pd::Tensor &self, paddle::platform::GPUPlace &place) {
+             self.mutable_data<float>(place);
+           })
+      .def("alloc_float",
+           [](pd::Tensor &self, paddle::platform::CPUPlace &place) {
+             self.mutable_data<float>(place);
            })
       .def("alloc_int",
-           [](Tensor &self) {
-             self.mutable_data<int>(paddle::platform::CPUPlace());
+           [](pd::Tensor &self, paddle::platform::CPUPlace &place) {
+             self.mutable_data<int>(place);
            })
-      .def("set", PyTensorSetFromArray<float>)
-      .def("set", PyTensorSetFromArray<int>)
-      .def("shape", [](Tensor &self) { return vectorize(self.dims()); });
+      .def("alloc_int",
+           [](pd::Tensor &self, paddle::platform::GPUPlace &place) {
+             self.mutable_data<int>(place);
+           })
+      .def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
+      .def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
+#ifndef PADDLE_ONLY_CPU
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
+#endif
+      .def("shape",
+           [](pd::Tensor &self) { return pd::vectorize(self.dims()); });
 
   py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
 
@@ -124,13 +152,29 @@ All parameter, weight, gradient are variables in Paddle.
   m.def_submodule(
        "var_names",
        "The module will return special predefined variable name in Paddle")
-      .def("empty", OperatorBase::EMPTY_VAR_NAME)
-      .def("temp", OperatorBase::TMP_VAR_NAME);
-
+      .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
+      .def("temp", pd::OperatorBase::TMP_VAR_NAME);
+  // clang-format off
   py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
-      .def_static("cpu_context", []() -> paddle::platform::DeviceContext * {
-        return new paddle::platform::CPUDeviceContext();
-      });
+      .def_static("create",
+                  [](paddle::platform::CPUPlace& place)
+                      -> paddle::platform::DeviceContext* {
+                    return new paddle::platform::CPUDeviceContext();
+                  })
+      .def_static("create",
+                  [](paddle::platform::GPUPlace& place)
+                      -> paddle::platform::DeviceContext* {
+#ifdef PADDLE_ONLY_CPU
+                    PADDLE_THROW("GPUPlace is not supported in CPU device.");
+#else
+                    return new paddle::platform::CUDADeviceContext(place);
+#endif
+                  });
+  // clang-format on
+
+  py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
+
+  py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
 
   py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
       m, "Operator");
@@ -144,6 +188,13 @@ All parameter, weight, gradient are variables in Paddle.
                    desc.InitializationErrorString());
     return OpRegistry::CreateOp(desc);
   });
+
+  operator_base.def("backward",
+                    [](const pd::OperatorBase &forwardOp,
+                       const std::unordered_set<std::string> &no_grad_vars) {
+                      return pd::Backward(forwardOp, no_grad_vars);
+                    });
+
   ExposeOperator(operator_base);
 
   py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net");
@@ -166,6 +217,8 @@ All parameter, weight, gradient are variables in Paddle.
 
   m.def("unique_integer", UniqueIntegerGenerator);
 
+  m.def("is_compile_gpu", IsCompileGPU);
+
   return m.ptr();
 }
 }  // namespace framework
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index c3e9a914f1d389b058d380b893441e34249f4293..4c3b14b83d841e88683a13634c93f51c012128b6 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -165,4 +165,4 @@ class Tensor {
 }  // namespace framework
 }  // namespace paddle
 
-#include "paddle/framework/detail/tensor-inl.h"
+#include "paddle/framework/tensor_impl.h"
diff --git a/paddle/framework/tensor_bind.h b/paddle/framework/tensor_bind.h
index 530b640f7051db2334c873bf4cd9608fcc0e88f1..4e1ab77b157fe1adaeac55c271c056236f2d40de 100644
--- a/paddle/framework/tensor_bind.h
+++ b/paddle/framework/tensor_bind.h
@@ -13,9 +13,11 @@
    limitations under the License. */
 
 #pragma once
-#include <paddle/framework/tensor.h>
-#include <pybind11/numpy.h>
-#include <pybind11/pybind11.h>
+#include <string>
+#include "paddle/framework/tensor.h"
+#include "paddle/memory/memcpy.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
 
 namespace py = pybind11;
 
@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS>
 struct CastToPyBufferImpl<true, I, ARGS...> {
   using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
   py::buffer_info operator()(framework::Tensor &tensor) {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
-                   "Only CPU tensor can cast to numpy array");
-
     if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
       auto dim_vec = framework::vectorize(tensor.dims());
       std::vector<size_t> dims_outside;
@@ -56,11 +55,16 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
         strides[i - 1] = sizeof(CUR_TYPE) * prod;
         prod *= dims_outside[i - 1];
       }
-
+      framework::Tensor dst_tensor;
+      if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
+        dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
+      } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
+        dst_tensor = tensor;
+      }
       return py::buffer_info(
-          tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
+          dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
           sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
-          (size_t)framework::arity(tensor.dims()), dims_outside, strides);
+          (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
     } else {
       constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
       return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
@@ -74,9 +78,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
 }
 
 template <typename T>
-void PyTensorSetFromArray(
+void PyCPUTensorSetFromArray(
     framework::Tensor &self,
-    py::array_t<T, py::array::c_style | py::array::forcecast> array) {
+    py::array_t<T, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::CPUPlace &place) {
   std::vector<int> dims;
   dims.reserve(array.ndim());
   for (size_t i = 0; i < array.ndim(); ++i) {
@@ -84,9 +89,28 @@ void PyTensorSetFromArray(
   }
 
   self.Resize(framework::make_ddim(dims));
-  auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
+  auto *dst = self.mutable_data<T>(place);
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }
 
+#ifndef PADDLE_ONLY_CPU
+template <typename T>
+void PyCUDATensorSetFromArray(
+    framework::Tensor &self,
+    py::array_t<T, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::GPUPlace &place) {
+  std::vector<int> dims;
+  dims.reserve(array.ndim());
+  for (size_t i = 0; i < array.ndim(); ++i) {
+    dims.push_back((int)array.shape()[i]);
+  }
+
+  self.Resize(framework::make_ddim(dims));
+  auto *dst = self.mutable_data<T>(place);
+  paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(),
+                                  cudaMemcpyHostToDevice);
+}
+#endif
+
 }  // namespace pybind
 }  // namespace paddle
diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/tensor_impl.h
similarity index 97%
rename from paddle/framework/detail/tensor-inl.h
rename to paddle/framework/tensor_impl.h
index e7ff09dd5c954378afeca299e901277c3ebdb96a..92621f8c18ec0d03160a23c462830d14272c7f64 100644
--- a/paddle/framework/detail/tensor-inl.h
+++ b/paddle/framework/tensor_impl.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-
 #include "paddle/memory/memcpy.h"
 
 namespace paddle {
@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
     if (platform::is_cpu_place(place)) {
       holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
           boost::get<platform::CPUPlace>(place), size));
+    } else if (platform::is_gpu_place(place)) {
+#ifdef PADDLE_ONLY_CPU
+      PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
     }
-#ifndef PADDLE_ONLY_CPU
-    else if (platform::is_gpu_place(place)) {
+#else
       holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
           boost::get<platform::GPUPlace>(place), size));
     }
diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h
index bb4f48364b9b454af7d37fe4d3c340666e53285c..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644
--- a/paddle/function/ConvOp.h
+++ b/paddle/function/ConvOp.h
@@ -109,6 +109,13 @@ protected:
     return filter[filter.ndims() - 1];
   }
 
+  // determine whether im2col needs to be performed
+  inline bool isNeedIm2col(const TensorShape& filter) const {
+    return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 &&
+             strideH() == 1 && strideW() == 1 && paddingH() == 0 &&
+             paddingW() == 0);
+  }
+
   std::vector<size_t> strides_;
   std::vector<size_t> paddings_;
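
The rationale for isNeedIm2col, as exploited in GemmConvOp.cpp below: a 1x1 filter with unit strides and zero padding leaves outputHeight == inputHeight and outputWidth == inputWidth, so the im2col expansion would reproduce the input slice element for element, and colData can simply alias the input. A worked shape check (sizes illustrative, not from this patch):

    // inputChannels / groups_ = 64, image 28x28, 1x1 filter, stride 1, pad 0:
    //   col buffer : TensorShape({64, 1, 1, 28, 28})  // 64 x 784 elements
    //   input slice: TensorShape({64, 28, 28})        // 64 x 784 elements
    // Identical layout, so colData = inputData + g * inputOffset is safe.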
 
diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp
index 9deb2739fcfff935a98a0b5b31b5d11819d81227..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644
--- a/paddle/function/GemmConvOp.cpp
+++ b/paddle/function/GemmConvOp.cpp
@@ -66,16 +66,23 @@ public:
     real* inputData = inputs[0].data<real>();
     real* filterData = inputs[1].data<real>();
     real* outputData = outputs[0].data<real>();
+    bool needIm2col = isNeedIm2col(filter);
+
     TensorShape imShape =
         TensorShape({inputChannels / groups_, inputHeight, inputWidth});
-    TensorShape colShape = TensorShape({inputChannels / groups_,
-                                        filterHeight,
-                                        filterWidth,
-                                        outputHeight,
-                                        outputWidth});
 
-    resizeBuffer<Device>(colShape.getElements());
-    real* colData = reinterpret_cast<real*>(memory_->getBuf());
+    TensorShape colShape;
+    real* colData = NULL;
+
+    if (needIm2col) {
+      colShape = TensorShape({inputChannels / groups_,
+                              filterHeight,
+                              filterWidth,
+                              outputHeight,
+                              outputWidth});
+      resizeBuffer<Device>(colShape.getElements());
+      colData = reinterpret_cast<real*>(memory_->getBuf());
+    }
 
     Im2ColFunctor<kCFO, Device, real> im2col;
     GemmFunctor<Device, real> gemm;
@@ -86,15 +93,18 @@ public:
 
     for (size_t i = 0; i < batchSize; i++) {
       for (size_t g = 0; g < groups_; g++) {
-        im2col(inputData + g * inputOffset,
-               imShape,
-               colData,
-               colShape,
-               strideH(),
-               strideW(),
-               paddingH(),
-               paddingW());
-
+        if (needIm2col) {
+          im2col(inputData + g * inputOffset,
+                 imShape,
+                 colData,
+                 colShape,
+                 strideH(),
+                 strideW(),
+                 paddingH(),
+                 paddingW());
+        } else {
+          colData = inputData + g * inputOffset;
+        }
         int M = outputChannels / groups_;
         int N = outputHeight * outputWidth;
         int K = inputChannels / groups_ * filterHeight * filterWidth;
@@ -159,19 +169,27 @@ public:
     real* outputGrad = inputs[0].data<real>();
     real* filterData = inputs[1].data<real>();
     real* inputGrad = outputs[0].data<real>();
+    bool needIm2col = isNeedIm2col(filter);
+
     TensorShape imShape =
         TensorShape({inputChannels / groups_, inputHeight, inputWidth});
-    TensorShape colShape = TensorShape({inputChannels / groups_,
-                                        filterHeight,
-                                        filterWidth,
-                                        outputHeight,
-                                        outputWidth});
 
-    resizeBuffer<Device>(colShape.getElements());
-    real* colData = reinterpret_cast<real*>(memory_->getBuf());
+    TensorShape colShape;
+    real* colData = NULL;
+
+    if (needIm2col) {
+      colShape = TensorShape({inputChannels / groups_,
+                              filterHeight,
+                              filterWidth,
+                              outputHeight,
+                              outputWidth});
+      resizeBuffer<Device>(colShape.getElements());
+      colData = reinterpret_cast<real*>(memory_->getBuf());
+    }
 
     Col2ImFunctor<kCFO, Device, real> col2im;
     GemmFunctor<Device, real> gemm;
+
     size_t inputOffset = imShape.getElements();
     size_t outputOffset =
         (outputChannels / groups_) * outputHeight * outputWidth;
@@ -182,6 +200,11 @@ public:
         int K = outputChannels / groups_;
         int N = outputHeight * outputWidth;
         int M = inputChannels / groups_ * filterHeight * filterWidth;
+        real scale = 0.0f;
+        if (!needIm2col) {
+          colData = inputGrad + g * inputOffset;
+          scale = 1.0f;
+        }
         gemm(CblasTrans,
              CblasNoTrans,
              M,
@@ -192,17 +215,19 @@ public:
              M,
              outputGrad + g * outputOffset,
              N,
-             0.0f,
+             scale,
              colData,
              N);
-        col2im(inputGrad + g * inputOffset,
-               imShape,
-               colData,
-               colShape,
-               strideH(),
-               strideW(),
-               paddingH(),
-               paddingW());
+        if (needIm2col) {
+          col2im(inputGrad + g * inputOffset,
+                 imShape,
+                 colData,
+                 colShape,
+                 strideH(),
+                 strideW(),
+                 paddingH(),
+                 paddingW());
+        }
       }
       inputGrad += inputChannels * inputHeight * inputWidth;
       outputGrad += outputChannels * outputHeight * outputWidth;
@@ -255,16 +280,23 @@ public:
     real* outputGrad = inputs[0].data<real>();
     real* inputData = inputs[1].data<real>();
     real* filterGrad = outputs[0].data<real>();
+    bool needIm2col = isNeedIm2col(filter);
+
     TensorShape imShape =
         TensorShape({inputChannels / groups_, inputHeight, inputWidth});
-    TensorShape colShape = TensorShape({inputChannels / groups_,
-                                        filterHeight,
-                                        filterWidth,
-                                        outputHeight,
-                                        outputWidth});
 
-    resizeBuffer<Device>(colShape.getElements());
-    real* colData = reinterpret_cast<real*>(memory_->getBuf());
+    TensorShape colShape;
+    real* colData = NULL;
+
+    if (needIm2col) {
+      colShape = TensorShape({inputChannels / groups_,
+                              filterHeight,
+                              filterWidth,
+                              outputHeight,
+                              outputWidth});
+      resizeBuffer<Device>(colShape.getElements());
+      colData = reinterpret_cast<real*>(memory_->getBuf());
+    }
 
     Im2ColFunctor<kCFO, Device, real> im2col;
     GemmFunctor<Device, real> gemm;
@@ -274,15 +306,18 @@ public:
     size_t filterOffset = filter.getElements() / groups_;
     for (size_t i = 0; i < batchSize; i++) {
       for (size_t g = 0; g < groups_; g++) {
-        im2col(inputData + g * inputOffset,
-               imShape,
-               colData,
-               colShape,
-               strideH(),
-               strideW(),
-               paddingH(),
-               paddingW());
-
+        if (needIm2col) {
+          im2col(inputData + g * inputOffset,
+                 imShape,
+                 colData,
+                 colShape,
+                 strideH(),
+                 strideW(),
+                 paddingH(),
+                 paddingW());
+        } else {
+          colData = inputData + g * inputOffset;
+        }
         int M = outputChannels / groups_;
         int K = outputHeight * outputWidth;
         int N = inputChannels / groups_ * filterHeight * filterWidth;
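
One reading of the new scale on the backward-data path: scale is GEMM's beta coefficient in C = alpha * op(A) * op(B) + beta * C, and its value tracks whether colData is scratch or an alias:

    // needIm2col == true : colData is a scratch buffer; beta = 0.0 overwrites
    //                      it, and col2im then scatters it into inputGrad.
    // needIm2col == false: colData aliases inputGrad + g * inputOffset;
    //                      beta = 1.0 accumulates this group's contribution
    //                      directly into the gradient.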
diff --git a/paddle/gserver/layers/ClipLayer.cpp b/paddle/gserver/layers/ClipLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..13f16c953793b82183237188b56eb61d76ecd2fd
--- /dev/null
+++ b/paddle/gserver/layers/ClipLayer.cpp
@@ -0,0 +1,79 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * A layer for clipping the input value by the threshold.
+ * \f[
+ *   out[i] = \min\left(\max\left(in[i],p_{1}\right),p_{2}\right)
+ * \f]
+ */
+
+class ClipLayer : public Layer {
+protected:
+  double min_;
+  double max_;
+
+public:
+  explicit ClipLayer(const LayerConfig& config) : Layer(config) {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+};
+
+REGISTER_LAYER(clip, ClipLayer);
+
+bool ClipLayer::init(const LayerMap& layerMap,
+                     const ParameterMap& parameterMap) {
+  Layer::init(layerMap, parameterMap);
+
+  CHECK_EQ(inputLayers_.size(), 1U);
+  auto layerConf = config_.inputs(0).clip_conf();
+  min_ = layerConf.min();
+  max_ = layerConf.max();
+  CHECK_LT(min_, max_);
+  return true;
+}
+
+void ClipLayer::forward(PassType passType) {
+  Layer::forward(passType);
+
+  MatrixPtr inV = getInputValue(0);
+  resetOutput(inV->getHeight(), inV->getWidth());
+  MatrixPtr outV = getOutputValue();
+  outV->copyFrom(*inV);
+  outV->clip(min_, max_);
+}
+
+void ClipLayer::backward(const UpdateCallback& callback) {
+  MatrixPtr inV = getInputValue(0);
+  MatrixPtr inG = getInputGrad(0);
+  if (inG) {
+    MatrixPtr outV = getOutputValue();
+    MatrixPtr outG = getOutputGrad();
+    MatrixPtr tmpMtx;
+    Matrix::resizeOrCreate(
+        tmpMtx, outG->getHeight(), outG->getWidth(), false, useGpu_);
+    tmpMtx->clipDerivative(*inV, min_, max_);
+    inG->addDotMul(*outG, *tmpMtx, 1, 1);
+  }
+}
+
+}  // namespace paddle
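
For intuition about the backward pass above: with out[i] = min(max(in[i], p1), p2), the derivative is 1 where p1 < in[i] < p2 and 0 elsewhere, so the input gradient is the output gradient masked by that indicator (the clipDerivative added to BaseMatrix.cu below). A tiny worked case with min_ = 0, max_ = 1:

    // in   = {-2, 0.5, 3}
    // out  = { 0, 0.5, 1}           // forward clip
    // mask = { 0, 1,   0}           // clipDerivative(in, 0, 1)
    // inG += outG .* mask           // the addDotMul in backward()
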
diff --git a/paddle/gserver/layers/RowL2NormLayer.cpp b/paddle/gserver/layers/RowL2NormLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0d609be43b73a86d0d0f7b60be993836e2ea6fff
--- /dev/null
+++ b/paddle/gserver/layers/RowL2NormLayer.cpp
@@ -0,0 +1,98 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * A layer for L2 normalization in each row,
+ * \f[
+ *   out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
+ * \f]
+ * where the size of \f$in\f$ is (batchSize x dataDim),
+ * and the size of \f$out\f$ is (batchSize x dataDim).
+ */
+
+class RowL2NormLayer : public Layer {
+protected:
+  MatrixPtr inSquare_;
+  MatrixPtr l2NormReciprocal_;
+  MatrixPtr dotSum_;
+
+public:
+  explicit RowL2NormLayer(const LayerConfig& config) : Layer(config) {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+};
+
+REGISTER_LAYER(row_l2_norm, RowL2NormLayer);
+
+bool RowL2NormLayer::init(const LayerMap& layerMap,
+                          const ParameterMap& parameterMap) {
+  Layer::init(layerMap, parameterMap);
+
+  CHECK_EQ(inputLayers_.size(), 1U);
+
+  return true;
+}
+
+void RowL2NormLayer::forward(PassType passType) {
+  Layer::forward(passType);
+
+  MatrixPtr inV = getInputValue(0);
+
+  /* malloc memory for the output_ if necessary */
+  size_t batchSize = inV->getHeight();
+  size_t dataDim = getSize();
+  CHECK_EQ(dataDim, inV->getWidth());
+  resetOutput(batchSize, dataDim);
+  MatrixPtr outV = getOutputValue();
+
+  Matrix::resizeOrCreate(inSquare_, batchSize, dataDim, false, useGpu_);
+  inV->square2(*inSquare_);
+  Matrix::resizeOrCreate(l2NormReciprocal_, batchSize, 1, false, useGpu_);
+  inSquare_->rowSum(*l2NormReciprocal_);
+  l2NormReciprocal_->sqrt2(*l2NormReciprocal_);
+  l2NormReciprocal_->scalarDiv(*l2NormReciprocal_, 1.0);
+  outV->rowScale(0, *inV, *l2NormReciprocal_);
+}
+
+void RowL2NormLayer::backward(const UpdateCallback& callback) {
+  MatrixPtr inV = getInputValue(0);
+  MatrixPtr inG = getInputGrad(0);
+  MatrixPtr outV = getOutputValue();
+  MatrixPtr outG = getOutputGrad();
+  size_t batchSize = inV->getHeight();
+
+  // inG[ij] += outG[ij] * l2NormReciprocal
+  // inG[ij] += -inV[ij] * l2NormReciprocal^3 *
+  //            DotMul(outG[i], inV[i])
+  if (inG) {
+    Matrix::resizeOrCreate(dotSum_, batchSize, 1, false, useGpu_);
+    dotSum_->zeroMem();
+    dotSum_->rowDotMul(0, *outG, *outV);
+    dotSum_->dotMul(*dotSum_, *l2NormReciprocal_);
+    dotSum_->dotMul(*dotSum_, *l2NormReciprocal_);
+    inSquare_->rowScale(0, *inV, *dotSum_);
+    inG->sub(*inSquare_);
+    inG->addRowScale(0, *outG, *l2NormReciprocal_);
+  }
+}
+
+}  // namespace paddle
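
For reference, a derivation consistent with the backward code above, writing r_i for the stored reciprocal norm:

    \f[
      out_{ij} = in_{ij} \, r_i, \qquad
      r_i = \Big(\sum_k in_{ik}^2\Big)^{-1/2}, \qquad
      \frac{\partial\, out_{ik}}{\partial\, in_{ij}}
        = \delta_{jk}\, r_i - in_{ij}\, in_{ik}\, r_i^{3}
    \f]

so inG_{ij} += outG_{ij} r_i - in_{ij} r_i^3 \sum_k outG_{ik} in_{ik}. The code obtains the r_i^3 factor from the two explicit dotMul's by l2NormReciprocal_ plus the r_i already folded into outV = inV r_i when dotSum_ is formed by rowDotMul(outG, outV).
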
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 8ce8600c6743779899b2685c1c12053922265411..fe11278f41c0118ee0bdb34f17fbf9602e0fa76b 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1899,6 +1899,36 @@ TEST(Layer, CropLayer) {
   }
 }
 
+TEST(Layer, ClipLayer) {
+  const size_t batchSize = 128;
+  const size_t size = 512;
+  TestConfig config;
+  config.layerConfig.set_type("clip");
+  config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  ClipConfig* layerConf = input->mutable_clip_conf();
+  double p1 = std::rand() / (double)RAND_MAX;
+  double p2 = std::rand() / (double)RAND_MAX;
+  layerConf->set_min(std::min(p1, p2));
+  layerConf->set_max(std::max(p1, p2));
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "clip", batchSize, false, useGpu, false);
+  }
+}
+
+TEST(Layer, RowL2NormLayer) {
+  const size_t batchSize = 128;
+  const size_t size = 512;
+  TestConfig config;
+  config.layerConfig.set_type("row_l2_norm");
+  config.layerConfig.set_size(size);
+  config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
+  config.layerConfig.add_inputs();
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "row_l2_norm", batchSize, false, useGpu, false);
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu
index de48b6fac9c7d8125a552022c52353ef6bcef995..6db5965789b3750f46731f157167150583130d0a 100644
--- a/paddle/math/BaseMatrix.cu
+++ b/paddle/math/BaseMatrix.cu
@@ -442,6 +442,12 @@ DEFINE_MATRIX_UNARY_PARAMETER_OP(Clip, TWO_PARAMETER,
 template<class T>
 void BaseMatrixT<T>::clip(T p1, T p2) { applyUnary(unary::Clip<T>(p1, p2)); }
 
+DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, a = b < p1 ? 0 : (b > p2 ? 0 : 1));
+template<class T>
+void BaseMatrixT<T>::clipDerivative(BaseMatrixT& b, T p1, T p2) {
+  applyBinary(binary::ClipDerivative<T>(p1, p2), b);
+}
+
 DEFINE_MATRIX_UNARY_PARAMETER_OP(BiggerThanScalar, ONE_PARAMETER,
                                  a = a > p ? 1.0f : 0.0f);
 template<class T>
diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h
index 120d69f718b954925438fbd2119d69f0be13b3e9..12ad2d45a0bbff182e78da6efb3c5ff4c6b59b55 100644
--- a/paddle/math/BaseMatrix.h
+++ b/paddle/math/BaseMatrix.h
@@ -488,6 +488,13 @@ public:
    */
   void clip(T p1, T p2);
 
+  /**
+   * @code
+   * this = (b < p1 || b > p2) ? 0 : 1
+   * @endcode
+   */
+  void clipDerivative(BaseMatrixT& b, T p1, T p2);
+
   /**
    * @code
    * a = a > p ? 1.0f : 0.0f
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index b910bee836ed488aeb34f28d0503b5efba396583..6465deeec93100f0238ac850b92f7f7c5a60b795 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -60,10 +60,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
 op_library(fc_op
     SRCS fc_op.cc
     DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
-
-op_library(recurrent_network_op
-    SRCS recurrent_network_op.cc
-    DEPS op_desc tensor net)
-cc_test(recurrent_network_op_test
-    SRCS recurrent_network_op_test.cc
-    DEPS recurrent_network_op mul_op add_op)
+op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
+cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 3a43dbfbada87e458109d8ca22effdb4407b4c1d..85269a5f7445a1745d9be68417789e33eb725d5c 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -50,10 +50,6 @@ The equation is: Out = X + Y
 class AddOpGrad : public OperatorWithKernel {
 protected:
   void InferShape(const InferShapeContext &ctx) const override {}
-  std::string DebugString() const override {
-    LOG(INFO) << "AddOpGrad";
-    return "";
-  }
 };
 
 }  // namespace operators
diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu
index 79d8de6cd46e1c72b14b0554c7be7b4eee281f4c..f961b37565f400b5c26844b9e7a3cff5e682340b 100644
--- a/paddle/operators/add_op.cu
+++ b/paddle/operators/add_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/add_op.h"
 
diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h
index d2b649fcbd1e5cac1c8cfcfd4e522e41135f7d1f..54d2231425293f6cfb3adc9cb34d903a75fcdcd0 100644
--- a/paddle/operators/add_op.h
+++ b/paddle/operators/add_op.h
@@ -28,10 +28,13 @@ public:
 
     output->mutable_data<T>(context.GetPlace());
 
-    EigenVector<T>::Flatten(*output).device(
-        *(context.GetEigenDevice<Place>())) =
-        framework::EigenVector<T>::Flatten(*input0) +
-        framework::EigenVector<T>::Flatten(*input1);
+    auto X = EigenVector<T>::Flatten(*input0);
+    auto Y = EigenVector<T>::Flatten(*input1);
+    auto Z = EigenVector<T>::Flatten(*output);
+
+    auto place = context.GetEigenDevice<Place>();
+
+    Z.device(place) = X + Y;
   }
 };
 
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 19e4b74596a0f59edd04db830ec6f6f481373465..926a0c616b957d8e542c1f3dee227a718fb29f07 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/cross_entropy_op.h"
 
 REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index fe34d6ad4015620cac520146850e10563d4c50e0..78131b26808b183ee107313374493ae870f1b641 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -33,13 +33,23 @@ public:
   MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of mean op");
-    AddOutput("Out", "The output of mean op");
+    AddOutput("Out", "The output of mean op").IgnoreGradient();
     AddComment("Mean Operator");
   }
 };
 
+class MeanGradOp : public OperatorWithKernel {
+protected:
+  void InferShape(const InferShapeContext &ctx) const override {
+    ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX())
+        ->Resize(ctx.Input<Tensor>("X")->dims());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
 REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
+REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
+REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu
index 740157cbc57a64cafcf109186c630691620f542b..e15de2fd0dd84e4015ee0e3b5343d7651b027a88 100644
--- a/paddle/operators/mean_op.cu
+++ b/paddle/operators/mean_op.cu
@@ -3,3 +3,4 @@
 #include "paddle/operators/mean_op.h"
 
 REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h
index 5f7d443751d1cdd7de3b67b0de2758ba1d566fb3..a89cb422f9b296dba6eb5358043f73d00aefc5d3 100644
--- a/paddle/operators/mean_op.h
+++ b/paddle/operators/mean_op.h
@@ -27,8 +27,28 @@ public:
 
     output->mutable_data<T>(context.GetPlace());
 
-    EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
-        EigenVector<T>::Flatten(*input).mean();
+    auto X = EigenVector<T>::Flatten(*input);
+    auto y = EigenScalar<T>::From(*output);
+    auto place = context.GetEigenDevice<Place>();
+
+    y.device(place) = X.mean();
+  }
+};
+
+template <typename Place, typename T>
+class MeanGradKernel : public OpKernel {
+public:
+  void Compute(const ExecutionContext& context) const override {
+    auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX());
+    PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
+                   "Mean Gradient should be scalar");
+    auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX());
+    IG->mutable_data<T>(context.GetPlace());
+
+    T ig_size = (T)framework::product(IG->dims());
+
+    EigenVector<T>::Flatten(*IG).device(*(context.GetEigenDevice<Place>())) =
+        EigenScalar<T>::From(*OG) / ig_size;
   }
 };
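
The gradient above distributes the scalar upstream gradient uniformly: for a mean over N elements, d mean / d x_i = 1/N. A tiny worked case:

    // product(IG->dims()) == 4 and the upstream gradient OG holds g:
    //   IG = {g/4, g/4, g/4, g/4}
    // i.e. EigenScalar::From(*OG) / ig_size broadcast over the flattened IG.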
 
diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu
index c27fc886ce7238a13c8ef86bce673a2b54949a9d..dc9236701627dc9335b844d2a82e18eb1f7dfd42 100644
--- a/paddle/operators/mul_op.cu
+++ b/paddle/operators/mul_op.cu
@@ -12,6 +12,7 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
+#define EIGEN_USE_GPU
 #include "paddle/operators/mul_op.h"
 
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index eef72ab293e13a9d05ce0013be41ec4bb75d6077..c7b78ad39045d25d73bfc2c930063c255a514864 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -26,13 +26,18 @@ public:
     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
         {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
 
+    auto input0 = context.Input<Tensor>("X");
+    auto input1 = context.Input<Tensor>("Y");
     auto output = context.Output<Tensor>(0);
+
     output->mutable_data<T>(context.GetPlace());
 
-    EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
-        EigenMatrix<T>::From(*context.Input<Tensor>("X"))
-            .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
-                      dim_pair);
+    auto X = EigenMatrix<T>::From(*input0);
+    auto Y = EigenMatrix<T>::From(*input1);
+    auto Z = EigenMatrix<T>::From(*output);
+    auto place = context.GetEigenDevice<Place>();
+
+    Z.device(place) = X.contract(Y, dim_pair);
   }
 };
 }  // namespace operators
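
The dim_pair above encodes a plain matrix product: IndexPair(1, 0) contracts dimension 1 of X (its columns) against dimension 0 of Y (its rows). Spelled out:

    // X: (M x K), Y: (K x N)  ==>  Z: (M x N)
    // Z(i, j) = sum_k X(i, k) * Y(k, j)
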
diff --git a/paddle/operators/recurrent_network_op.cc b/paddle/operators/recurrent_op.cc
similarity index 67%
rename from paddle/operators/recurrent_network_op.cc
rename to paddle/operators/recurrent_op.cc
index 60d065fc4789f76370840328870165579aa73b67..e5b76e3724b5b0287071c90d26235b8e1a1d80cf 100644
--- a/paddle/operators/recurrent_network_op.cc
+++ b/paddle/operators/recurrent_op.cc
@@ -12,7 +12,7 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 #include <glog/logging.h>
 #include <cstring>
@@ -29,11 +29,15 @@ namespace rnn {
 
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
-                   const size_t seq_len) {
+                   const size_t seq_len,
+                   bool infer_shape_mode) {
   PADDLE_ENFORCE(!inlinks.empty(), "no inlinks are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
-    Tensor* input =
-        step_scopes[0]->FindVar(inlinks[i].external)->GetMutable<Tensor>();
+    auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
+    PADDLE_ENFORCE(input_var != nullptr,
+                   "input link [%s] is not in scope.",
+                   inlinks[i].external);
+    Tensor* input = input_var->GetMutable<Tensor>();
     DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -41,7 +45,9 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
-      *step_input = input->Slice<float>(j, j + 1);
+      if (!infer_shape_mode) {
+        *step_input = input->Slice<float>(j, j + 1);
+      }
       step_input->Resize(step_dims);
     }
   }
@@ -49,36 +55,41 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
 
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
-                   const size_t seq_len) {
+                   const size_t seq_len,
+                   bool infer_shape_mode) {
   for (size_t i = 0; i < outlinks.size(); i++) {
-    Tensor* output =
-        step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
-
-    // TODO(qingiqng) remove following code after adding
-    // InferShape in RecurrentGradientOp
-    DDim step_dims = step_scopes[0]
-                         ->FindVar(outlinks[i].internal)
-                         ->GetMutable<Tensor>()
-                         ->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
-    dims_vec.insert(dims_vec.begin(), seq_len);
-    output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace());
-
-    for (size_t j = 0; j < seq_len; j++) {
-      Tensor* step_output =
-          step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
-      // TODO(luotao02) data type and platform::DeviceContext() should set
-      // correctly
-      (output->Slice<float>(j, j + 1))
-          .CopyFrom<float>(*step_output, platform::CPUPlace());
+    auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
+    PADDLE_ENFORCE(output_var != nullptr,
+                   "output link [%s] is not in scope.",
+                   outlinks[i].external);
+    Tensor* output = output_var->GetMutable<Tensor>();
+    if (infer_shape_mode) {
+      DDim step_dims = step_scopes[0]
+                           ->FindVar(outlinks[i].internal)
+                           ->GetMutable<Tensor>()
+                           ->dims();
+      std::vector<int> dims_vec = vectorize(step_dims);
+      dims_vec.insert(dims_vec.begin(), seq_len);
+      output->Resize(make_ddim(dims_vec));
+    } else {
+      output->mutable_data<float>(platform::CPUPlace());
+      for (size_t j = 0; j < seq_len; j++) {
+        Tensor* step_output =
+            step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
+        // TODO(luotao02) data type and platform::DeviceContext() should set
+        // correctly
+        (output->Slice<float>(j, j + 1))
+            .CopyFrom<float>(*step_output, platform::CPUPlace());
+      }
     }
   }
 }
 
 void LinkMemories(const std::vector<Scope*>& scopes,
                   const std::vector<rnn::MemoryAttr>& memories,
-                  size_t step_id,
-                  int offset) {
+                  const size_t step_id,
+                  const int offset,
+                  bool infer_shape_mode) {
   PADDLE_ENFORCE(step_id < scopes.size(),
                  "step [%d] is out of range of step scopes' size [%d]",
                  step_id,
@@ -95,18 +106,13 @@ void LinkMemories(const std::vector<Scope*>& scopes,
   auto scope = scopes[step_id];
   auto linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
-    auto mem = scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
-    // maybe share variable is better?
+    auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
     auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
-    mem->ShareDataWith<float>(*linked_mem);
-
-    // TODO(qingqing) remove following code
-    // the memory of current step should be allocated in step net
-    auto m = scope->NewVar(attr.var)->GetMutable<Tensor>();
-    // for unit test, as addOp and mulOp are null currently, if not
-    // mutable_data, mem.data() in output will be error. We will
-    // remove this line after merge the correct addOp and mulOp.
-    m->mutable_data<float>(mem->dims(), platform::CPUPlace());
+    if (infer_shape_mode) {
+      mem->Resize(linked_mem->dims());
+    } else {
+      mem->ShareDataWith<float>(*linked_mem);
+    }
   }
 }
 
@@ -175,60 +181,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
                  ->dims()[0];
   CreateScopes(scope);
   auto step_scopes = GetStepScopes(scope);
-
-  // SegmentInputs is called in InferShape. The input must hold memory in
-  // SegmentInputs. But the other op only set dimension for the output in
-  // InferShape. That's a problem. Wether the RNN op needs InferShape or not?
-  // Wether the following functions (SegmentInputs, InitMemories, ...) need
-  // to rewrite for RNN op?
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-
-  InitMemories(step_scopes[0]);
-
-  PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
-                 "stepnet [%s] is not in scope.",
-                 arg_->step_net);
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], true /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
-  // If the InferShape is called in OperatorBase's run function,
-  // the rnn op only needs to do InferShape for the first time step
   for (size_t i = 0; i < seq_len_; i++) {
     if (i > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
   }
-
-  auto outlinks = arg_->outlinks;
-  for (size_t i = 0; i < outlinks.size(); i++) {
-    DDim step_dims = step_scopes[0]
-                         ->FindVar(outlinks[i].internal)
-                         ->GetMutable<Tensor>()
-                         ->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
-    // now only support fixed length
-    dims_vec.insert(dims_vec.begin(), seq_len_);
-    Tensor* output =
-        step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
-    output->Resize(make_ddim(dims_vec));
-  }
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
 }
 
 void RecurrentAlgorithm::Run(const Scope& scope,
                              const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
+
   for (size_t step_id = 0; step_id < seq_len_; step_id++) {
-    // the link memory is done in InferShape
-    // maybe remove following code after testing
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
   }
-
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }
 
 void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
@@ -244,18 +229,19 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
       // Now all variables in scope must be created outside of op.
       auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>();
       for (auto& input : net_op->inputs_) {
+        // the weights are located in the parent scope
         if (!step_scope.FindVar(input)) step_scope.NewVar(input);
       }
       for (auto& output : net_op->outputs_) {
         step_scope.NewVar(output);
       }
-
       step_scopes->emplace_back(&step_scope);
     }
   }
 }
 
-void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
+void RecurrentAlgorithm::InitMemories(Scope* step_scope,
+                                      bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
     Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
@@ -263,13 +249,11 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
                    attr.var,
                    attr.boot_var);
     Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
-    pre_mem->ShareDataWith<float>(*boot_mem);
-
-    // TODO(qingqing) remove following code
-    // the memory of current step should be allocated in step net
-    // here for unit test
-    auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
-    cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace());
+    if (infer_shape_mode) {
+      pre_mem->Resize(boot_mem->dims());
+    } else {
+      pre_mem->ShareDataWith<float>(*boot_mem);
+    }
   }
 }
 
@@ -307,13 +291,14 @@ public:
       : OpProtoAndCheckerMaker(proto, op_checker) {
     const auto& name = RecurrentOp::kArgName;
     // inputs and outputs stored in proto
-    AddInput(name.inlinks, "the input that need to be segmented for each step.")
+    AddInput(name.inlinks,
+             "the inputs that need to be segmented for each step.")
         .SetMultiple();
     AddInput(name.boot_memories, "variables to initialize memories.")
         .SetMultiple();
     AddInput(name.step_net, "network shared by all steps.");
 
-    AddOutput(name.outlinks, "the output that need to concated for all steps.")
+    AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.")
         .SetMultiple();
     AddOutput(name.step_scopes, "step scopes");
 
@@ -331,34 +316,39 @@ public:
 void RecurrentGradientAlgorithm::Run(
     const Scope& scope, const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-  PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
   }
-  LinkBootMemoryGradients(step_scopes[0]);
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+  LinkBootMemoryGradients(step_scopes[0], false);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }
 
 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
-    Scope* step_scope) const {
+    Scope* step_scope, bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
-    Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
-    PADDLE_ENFORCE(mem_grad != nullptr,
-                   "boot_tensor should be retrieved before");
+    PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
+                   "memory variable [%s] does not exists",
+                   attr.var);
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
-                   "memory [%s]'s boot variable [%s] not exists",
-                   attr.var,
+                   "boot variable [%s] does not exists",
                    attr.boot_var);
+    Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
     Tensor* boot_mem_grad =
         step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>();
-    boot_mem_grad->ShareDataWith<float>(*mem_grad);
+    if (infer_shape_mode) {
+      boot_mem_grad->Resize(mem_grad->dims());
+    } else {
+      boot_mem_grad->ShareDataWith<float>(*mem_grad);
+    }
   }
 }
 
@@ -367,34 +357,20 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
                  ->GetMutable<Tensor>()
                  ->dims()[0];
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-
-  PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
-
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
   }
-
-  auto outlinks = arg_->outlinks;
-  for (size_t i = 0; i < outlinks.size(); i++) {
-    DDim step_dims = step_scopes[0]
-                         ->FindVar(outlinks[i].internal)
-                         ->GetMutable<Tensor>()
-                         ->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
-    // now only support fixed length
-    dims_vec.insert(dims_vec.begin(), seq_len_);
-    Tensor* output =
-        step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
-    output->Resize(make_ddim(dims_vec));
-  }
-  LinkBootMemoryGradients(step_scopes[0]);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
+  LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
 }
 
 void RecurrentGradientOp::Init() {
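
The infer_shape_mode flag threaded through SegmentInputs, ConcatOutputs, LinkMemories, InitMemories, and LinkBootMemoryGradients encodes a single convention: during InferShape only tensor metadata moves, while Run slices, shares, or copies real storage. The recurring dispatch, as in LinkMemories above:

    if (infer_shape_mode) {
      mem->Resize(linked_mem->dims());         // propagate shapes only
    } else {
      mem->ShareDataWith<float>(*linked_mem);  // bind real storage at run time
    }
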
diff --git a/paddle/operators/recurrent_network_op.h b/paddle/operators/recurrent_op.h
similarity index 92%
rename from paddle/operators/recurrent_network_op.h
rename to paddle/operators/recurrent_op.h
index d57a1a2e51cbed22549ab6ebce79223e2d4e3bcf..2a0964fff326500b6215dd4afac63c75d64c4a06 100644
--- a/paddle/operators/recurrent_network_op.h
+++ b/paddle/operators/recurrent_op.h
@@ -72,19 +72,22 @@ struct ArgumentName {
  */
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
-                   const size_t seq_len);
+                   const size_t seq_len,
+                   bool infer_shape_mode);
 
 /**
  * Process outputs of step nets and merge to variables.
  */
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
-                   const size_t seq_len);
+                   const size_t seq_len,
+                   bool infer_shape_mode);
 
 void LinkMemories(const std::vector<Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories,
-                  size_t step_id,
-                  int offset);
+                  const size_t step_id,
+                  const int offset,
+                  bool infer_shape_mode);
 
 void InitArgument(const ArgumentName& name, Argument* arg);
 
@@ -122,7 +125,7 @@ protected:
     return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
   }
 
-  void InitMemories(Scope* step_scopes) const;
+  void InitMemories(Scope* step_scopes, bool infer_shape_mode) const;
 
 private:
   std::unique_ptr<rnn::Argument> arg_;
@@ -145,7 +148,7 @@ public:
 
   void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
 
-  void LinkBootMemoryGradients(Scope* step_scopes) const;
+  void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const;
 
   /**
    * InferShape must be called before Run.
diff --git a/paddle/operators/recurrent_network_op_test.cc b/paddle/operators/recurrent_op_test.cc
similarity index 90%
rename from paddle/operators/recurrent_network_op_test.cc
rename to paddle/operators/recurrent_op_test.cc
index b0e61fbee611744adb85b498b1c3540f059afc8c..91f2972ca49953fd7a627289fa37db32916d85cd 100644
--- a/paddle/operators/recurrent_network_op_test.cc
+++ b/paddle/operators/recurrent_op_test.cc
@@ -18,7 +18,7 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 namespace paddle {
 namespace operators {
@@ -55,7 +55,7 @@ protected:
     w->GetMutable<Tensor>()->mutable_data<float>(
         make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
 
-    for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
+    for (auto boot : std::vector<std::string>{"h_boot"}) {
       LOG(INFO) << "create global variable " << boot;
       Variable* h_boot = scope_.NewVar(boot);
       h_boot->GetMutable<Tensor>()->mutable_data<float>(
@@ -79,7 +79,6 @@ protected:
     op_desc.add_inputs("x0");
     op_desc.add_inputs("x1");
     // boot_memories 3
-    op_desc.add_inputs("x_boot");
     op_desc.add_inputs("h_boot");
     // step net 5
     op_desc.add_inputs("step_net");
@@ -91,7 +90,7 @@ protected:
     auto _input_format = std::vector<int>{
         0,  // in_link
         3,  // memories
-        5   // step_net
+        4   // step_net
     };
     auto input_format = op_desc.add_attrs();
     input_format->set_name("input_format");
@@ -129,12 +128,11 @@ protected:
       inlink_alias->add_strings(item);
     }
     // pre memories
-    for (const auto& item :
-         std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
+    for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) {
       pre_memories->add_strings(item);
     }
     // memories
-    for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
+    for (const auto& item : std::vector<std::string>{"rnn/h"}) {
       memories->add_strings(item);
     }
     // output alias
@@ -151,14 +149,11 @@ protected:
     LOG(INFO) << "create variable step_net";
     Variable* var = scope_.NewVar("step_net");
     auto net = var->GetMutable<NetOp>();
-    // rnn/s is net's input or output?
-    net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
-    net->inputs_ = {"rnn/s", "rnn/h"};
     net->AddOp(
         OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
 
     net->AddOp(
-        OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
+        OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
     net->CompleteAddOp();
   }
 
@@ -297,7 +292,10 @@ protected:
     inlink.internal = "rnn/x";
     auto step_scopes =
         scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
-    rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
+    rnn::SegmentInputs(*step_scopes,
+                       std::vector<rnn::Link>{inlink},
+                       10,
+                       true /*infer_shape_mode*/);
   }
 
   void LinkeMemories() {
@@ -311,7 +309,8 @@ protected:
     auto step_scopes =
         scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
     for (int i = 1; i < 10; ++i) {
-      rnn::LinkMemories(*step_scopes, memories, i, -1);
+      rnn::LinkMemories(
+          *step_scopes, memories, i, -1, true /*infer_shape_mode*/);
     }
   }
 
@@ -333,14 +332,14 @@ TEST(RecurrentOp, LinkMemories) {
   using namespace paddle::operators;
 
   // create and init step scopes
-  int len = 10;
+  size_t len = 10;
   std::vector<Scope*> step_scopes;
-  for (int i = 0; i < len; ++i) {
+  for (size_t i = 0; i < len; ++i) {
     auto scope = new Scope();
     scope->NewVar("pre_h");
     auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
     float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
-    for (int j = 0; j < 15 * 20; ++j) {
+    for (size_t j = 0; j < 15 * 20; ++j) {
       data[j] = rand() * (1. / (double)RAND_MAX);
     }
     step_scopes.push_back(scope);
@@ -354,24 +353,24 @@ TEST(RecurrentOp, LinkMemories) {
   std::vector<rnn::MemoryAttr> memories;
   memories.push_back(mem_attr);
 
-  for (int i = 1; i < len; ++i) {
-    rnn::LinkMemories(step_scopes, memories, i, -1);
+  for (size_t i = 1; i < len; ++i) {
+    rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
   }
   // check
-  for (int i = 0; i < len - 1; ++i) {
+  for (size_t i = 0; i < len - 1; ++i) {
     const float* a =
         step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
     const float* b = step_scopes[i + 1]
                          ->FindVar("pre_h")
                          ->GetMutable<Tensor>()
                          ->data<float>();
-    for (size_t i = 0; i < 15 * 20; ++i) {
-      ASSERT_FLOAT_EQ(a[i], b[i]);
+    for (size_t j = 0; j < 15 * 20; ++j) {
+      ASSERT_FLOAT_EQ(a[j], b[j]);
     }
   }
 
   for (int i = len - 2; i >= 0; --i) {
-    rnn::LinkMemories(step_scopes, memories, i, 1);
+    rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/);
   }
   // check
   for (int i = len - 2; i >= 0; --i) {
@@ -379,8 +378,8 @@ TEST(RecurrentOp, LinkMemories) {
         step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
     const float* b =
         step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
-    for (size_t i = 0; i < 15 * 20; ++i) {
-      ASSERT_FLOAT_EQ(a[i], b[i]);
+    for (size_t j = 0; j < 15 * 20; ++j) {
+      ASSERT_FLOAT_EQ(a[j], b[j]);
     }
   }
 
@@ -391,9 +390,3 @@ TEST(RecurrentOp, LinkMemories) {
 
 USE_OP(add_two);
 USE_OP(mul);
-
-// int main() {
-//  //! TODO(yuyang18): Temporary disable this unit-test because implementation
-//  //! error.
-//  return 0;
-//}
\ No newline at end of file
diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu
index 4b33e38ebabe853e179fe70ef7fde0a80b9050e2..82338ceccc06653791b26472e18d804f62735649 100644
--- a/paddle/operators/rowwise_add_op.cu
+++ b/paddle/operators/rowwise_add_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/rowwise_add_op.h"
 
 REGISTER_OP_GPU_KERNEL(rowwise_add,
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index b86dd5463436bf521f9939b1c421b39f11102769..bd4d1128955fb718d3a84dfd96d8c68d7196e9cc 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,7 +33,7 @@ public:
     const int rest_size = input.size() / bias_size;
     Eigen::DSizes<int, 1> one_d(input.size());
     Eigen::DSizes<int, 1> bcast(rest_size);
-    output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) =
+    output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
         input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
   }
 };
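For reference, the rowwise-add kernel keeps Eigen's flatten-and-broadcast structure; semantically it is just `X + b` with the bias broadcast over rows. A NumPy sketch of the same computation (not the kernel itself):

```python
import numpy as np

def rowwise_add(x, b):
    # Mirror the Eigen kernel: flatten to 1-D, tile the bias across
    # rest_size = input.size / bias_size rows, add, restore the shape.
    rest_size = x.size // b.size
    flat = x.reshape(-1) + np.tile(b, rest_size)
    return flat.reshape(x.shape)

x = np.random.random((32, 84)).astype("float32")
b = np.random.random(84).astype("float32")
assert np.allclose(rowwise_add(x, b), x + b)  # same as NumPy broadcasting
```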
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
index f8f5b90cab460b4457cfb0a88bfc012bafe0fbc2..d79258cbf13c699cfb2afaee229cf96a3e377b5e 100644
--- a/paddle/operators/sgd_op.cu
+++ b/paddle/operators/sgd_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/sgd_op.h"
 
 REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index af1dfdd756ceb9991bee6b85c3281c05f0fb5a9f..0c3a240f9a4a5fc7bc4898e82786810cee2f7010 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -29,8 +29,12 @@ public:
 
     param_out->mutable_data<T>(ctx.GetPlace());
 
-    EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
-        EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
+    auto p = EigenVector<T>::Flatten(*param);
+    auto g = EigenVector<T>::Flatten(*grad);
+    auto o = EigenVector<T>::Flatten(*param_out);
+    auto place = ctx.GetEigenDevice<Place>();
+
+    o.device(place) = p - lr * g;
   }
 };
 
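The SGD refactor only names the flattened Eigen expressions (`p`, `g`, `o`) before the device-placed assignment; the update itself is unchanged. In NumPy terms (a semantic sketch):

```python
import numpy as np

param = np.random.random((102, 105)).astype("float32")
grad = np.random.random((102, 105)).astype("float32")
lr = 0.1

# o.device(place) = p - lr * g
param_out = param - lr * grad
```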
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu
index f679b20418f04eff4310efe4e121963ce5a235e0..c9d11a2e1f9dcc563765c9e8cc1bae6beff57f18 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/sigmoid_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/sigmoid_op.h"
 
 REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
index 3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f..1412e4398440c8e946d3ab434a50e978079637ab 100644
--- a/paddle/operators/sigmoid_op.h
+++ b/paddle/operators/sigmoid_op.h
@@ -27,9 +27,11 @@ public:
     auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
 
-    EigenVector<T>::Flatten(*output).device(
-        *(context.GetEigenDevice<Place>())) =
-        1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
+    auto X = EigenVector<T>::Flatten(*input);
+    auto Y = EigenVector<T>::Flatten(*output);
+    auto place = context.GetEigenDevice<Place>();
+
+    Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
   }
 };
 }  // namespace operators
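The sigmoid kernel follows the same naming pattern; the math reduces to the elementwise logistic function. A NumPy sketch:

```python
import numpy as np

def sigmoid(x):
    # Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp())
    return 1.0 / (1.0 + np.exp(-x))

x = np.random.random((32, 100)).astype("float32")
y = sigmoid(x)
assert y.shape == x.shape and (0 < y).all() and (y < 1).all()
```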
diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu
index a1f6944a369fe5148ffcfeabf3bf7063dcbc2664..ddf8f6e913ccf450185f377f531bf978f69ed1fc 100644
--- a/paddle/operators/softmax_op.cu
+++ b/paddle/operators/softmax_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/softmax_op.h"
 
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index a5c19c5fc7c6f5909dbb355aff09bf15405b6957..75c5197697dada58e09f4cda41cea13af56e79a3 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -46,9 +46,9 @@ public:
                                .reshape(batch_by_one)
                                .broadcast(one_by_class));
 
-    softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
+    softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
 
-    softmax.device(*(context.GetEigenDevice<Place>())) =
+    softmax.device(context.GetEigenDevice<Place>()) =
         (softmax *
          softmax.sum(along_class)
              .inverse()
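The surrounding context is the standard numerically stable softmax: shift the logits by the per-row maximum, exponentiate, then normalize by the per-row sum; the shift leaves the result unchanged but avoids overflow in `exp`. A NumPy sketch of the same three steps:

```python
import numpy as np

def softmax(logits):
    # shifted_logits = logits - rowwise max, broadcast back over classes
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)                      # softmax = shifted_logits.exp()
    return e / e.sum(axis=1, keepdims=True)  # softmax * sum(...)^{-1}

x = 100.0 * np.random.random((4, 10)).astype("float32")
p = softmax(x)
assert np.allclose(p.sum(axis=1), 1.0)
```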
diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h
index 93b62cddc819e0d1fd48323e474a294ff0d327e1..9049ffda1da5408411687474c5ed0c76c2394623 100644
--- a/paddle/operators/type_alias.h
+++ b/paddle/operators/type_alias.h
@@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace;
 using GPUPlace = platform::GPUPlace;
 using NetOp = framework::NetOp;
 using OpRegistry = framework::OpRegistry;
+using OperatorBase = framework::OperatorBase;
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
index 26c8eb78e614a68ec9728aad727d8fe3e08547ae..60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35 100644
--- a/paddle/platform/enforce.h
+++ b/paddle/platform/enforce.h
@@ -144,12 +144,12 @@ inline void throw_on_error(T e) {
   throw_on_error(e, "");
 }
 
-#define PADDLE_THROW(...)                                      \
-  do {                                                         \
-    throw ::paddle::platform::EnforceNotMet(                   \
-        std::make_exception_ptr(                               \
-            std::runtime_error(string::Sprintf(__VA_ARGS__))), \
-        __FILE__, __LINE__);                                   \
+#define PADDLE_THROW(...)                                              \
+  do {                                                                 \
+    throw ::paddle::platform::EnforceNotMet(                           \
+        std::make_exception_ptr(                                       \
+            std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
+        __FILE__, __LINE__);                                           \
   } while (0)
 
 #define PADDLE_ENFORCE(...)                                             \
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..29dd0ded0ac75893da7e244d92725cd5e285efce
--- /dev/null
+++ b/paddle/pybind/CMakeLists.txt
@@ -0,0 +1,9 @@
+cc_library(paddle_pybind SHARED
+    SRCS pybind.cc
+    DEPS pybind python backward
+    fc_op
+    sgd_op
+    add_op
+    mean_op
+    cross_entropy_op
+    recurrent_op)
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index 3860facb099950a5287d3f6b89c3de38f588f568..69ae0ea2d72c199a8e17c0595693e5e0b2f79ee1 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -148,7 +148,7 @@ cat >> /paddle/build/Dockerfile <<EOF
 ADD *.deb /
 # run paddle version to install python packages first
 RUN apt-get update &&\
-    apt-get install -y python-pip && pip install -U pip && \
+    apt-get install -y wget python-pip && pip install -U pip && \
     dpkg -i /*.deb ; apt-get install -f -y && \
     apt-get clean -y && \
     rm -f /*.deb && \
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 3bee5b572ae42750332b69e28af980ae325532da..b50b73c7e169f3e8ae75322d9a0a3cad5072a9c7 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -298,6 +298,11 @@ message DetectionOutputConfig {
   optional uint32 width = 9 [default = 1];
 }
 
+message ClipConfig {
+  required double min = 1;
+  required double max = 2;
+}
+
 message LayerInputConfig {
   required string input_layer_name = 1;
   optional string input_parameter_name = 2;
@@ -318,6 +323,7 @@ message LayerInputConfig {
   optional RowConvConfig row_conv_conf = 15;
   optional MultiBoxLossConfig multibox_loss_conf = 16;
   optional DetectionOutputConfig detection_output_conf = 17;
+  optional ClipConfig clip_conf = 18;
 }
 
 message LayerConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index f71fefffb59d4a53dda092ff83a61d9eec4b601f..9ea69fc5e57636c22fb20d5d97de760b9cc3bcde 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -2198,6 +2198,20 @@ class RowConvLayer(LayerBase):
         self.create_input_parameter(0, psize, dims)
 
 
+@config_layer('clip')
+class ClipLayer(LayerBase):
+    def __init__(self, name, inputs, min, max, **xargs):
+        super(ClipLayer, self).__init__(name, 'clip', 0, inputs=inputs, **xargs)
+        config_assert(
+            len(self.inputs) == 1,
+            'ClipLayer must have one and only one input.')
+        config_assert(min < max, 'min must be less than max.')
+        input_layer = self.get_input_layer(0)
+        self.set_layer_size(input_layer.size)
+        self.config.inputs[0].clip_conf.min = min
+        self.config.inputs[0].clip_conf.max = max
+
+
 # key: cost type
 # value: cost class
 g_cost_map = {}
@@ -2754,6 +2768,16 @@ class SumToOneNormLayer(LayerBase):
         self.set_layer_size(input_layer0.size)
 
 
+@config_layer('row_l2_norm')
+class RowL2NormLayer(LayerBase):
+    def __init__(self, name, inputs, **xargs):
+        super(RowL2NormLayer, self).__init__(
+            name, 'row_l2_norm', 0, inputs=inputs, **xargs)
+        config_assert(
+            len(self.inputs) == 1,
+            'RowL2NormLayer must have one and only one input.')
+        input_layer = self.get_input_layer(0)
+        self.set_layer_size(input_layer.size)
+
+
 @config_layer('cos_vm')
 class CosSimVecMatLayer(LayerBase):
     def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 965874ddf632a83d00065c2d40037930a6e604a8..ea5fdcc50f6abbc67fb61b7fd56c100d9f9811d0 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -76,6 +76,7 @@ __all__ = [
     'trans_layer',
     'rotate_layer',
     'sum_to_one_norm_layer',
+    'row_l2_norm_layer',
     'get_output_layer',
     'LayerType',
     'context_projection',
@@ -128,6 +129,7 @@ __all__ = [
     'prelu_layer',
     'gated_unit_layer',
     'crop_layer',
+    'clip_layer',
     'slice_projection',
 ]
 
@@ -160,6 +162,7 @@ class LayerType(object):
     BATCH_NORM_LAYER = 'batch_norm'
     NORM_LAYER = 'norm'
     SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
+    ROW_L2_NORM_LAYER = 'row_l2_norm'
     ADDTO_LAYER = 'addto'
 
     CONCAT_LAYER = 'concat'
@@ -221,6 +224,7 @@ class LayerType(object):
 
     PRELU = 'prelu'
     CROP_LAYER = 'crop'
+    CLIP_LAYER = 'clip'
 
     @staticmethod
     def is_layer_type(type_name):
@@ -2889,6 +2893,42 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
         name, LayerType.SUM_TO_ONE_NORM_LAYER, parents=[input], size=input.size)
 
 
+@wrap_name_default()
+@layer_support()
+def row_l2_norm_layer(input, name=None, layer_attr=None):
+    r"""
+    A layer that applies L2 normalization to each row of its input.
+
+    .. math::
+       out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
+
+    where the size of :math:`in` is (batchSize x dataDim),
+    and the size of :math:`out` is also (batchSize x dataDim).
+
+    The example usage is:
+
+    .. code-block:: python
+
+       norm = row_l2_norm_layer(input=layer)
+
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param name: Layer name.
+    :type name: basestring
+    :param layer_attr: Extra layer attributes.
+    :type layer_attr: ExtraLayerAttribute
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    Layer(
+        name=name,
+        type=LayerType.ROW_L2_NORM_LAYER,
+        inputs=[input.name],
+        **ExtraAttr.to_kwargs(layer_attr))
+    return LayerOutput(
+        name, LayerType.ROW_L2_NORM_LAYER, parents=[input], size=input.size)
+
+
 @wrap_name_default("addto")
 @wrap_act_default(act=LinearActivation())
 @wrap_bias_attr_default(has_bias=False)
@@ -6046,3 +6086,36 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
         layer_type=LayerType.CROP_LAYER,
         parents=input,
         size=l.config.size)
+
+
+@wrap_name_default("clip")
+def clip_layer(input, min, max, name=None):
+    r"""
+    A layer that clips each input value to the interval [min, max].
+
+    .. math::
+
+        out[i] = \min\left(\max\left(in[i], p_{1}\right), p_{2}\right)
+
+    where :math:`p_{1}` is ``min`` and :math:`p_{2}` is ``max``.
+
+    The example usage is:
+
+    .. code-block:: python
+
+        clip = clip_layer(input=input_layer, min=-10, max=10)
+
+    :param name: The layer name.
+    :type name: basestring
+    :param input: The input layer.
+    :type input: LayerOutput
+    :param min: The lower threshold for clipping.
+    :type min: double
+    :param max: The upper threshold for clipping.
+    :type max: double
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    Layer(
+        name=name,
+        type=LayerType.CLIP_LAYER,
+        inputs=[input.name],
+        min=min,
+        max=max)
+    return LayerOutput(
+        name, LayerType.CLIP_LAYER, parents=[input], size=input.size)
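The forward semantics of the two new layers correspond directly to `numpy.clip` and a per-row L2 normalization. A reference sketch of the expected outputs, assuming the semantics stated in the docstrings above (not taken from the C++ layer code):

```python
import numpy as np

batch = np.random.randn(8, 300).astype("float32")

# clip_layer: out[i] = min(max(in[i], p1), p2), here with p1=-10, p2=10
clipped = np.clip(batch, -10, 10)

# row_l2_norm_layer: out[i] = in[i] / sqrt(sum_k in[k]^2), per row
norms = np.sqrt((batch ** 2).sum(axis=1, keepdims=True))
row_normed = batch / norms

assert np.allclose(np.sqrt((row_normed ** 2).sum(axis=1)), 1.0)
```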
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index cdf9b2eab733adb173cf33cd6a93ef7b5abefc50..0ffa58bc1e2088f75e7cd25c7ecdffbe270825a4 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
 test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
 test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
 test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
-test_recursive_topology test_gated_unit_layer)
+test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer)
 
 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_clip_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_clip_layer.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..4b9578a0c050ef74f186485fec3f6c1f7a0f0814
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_clip_layer.protostr
@@ -0,0 +1,31 @@
+type: "nn"
+layers {
+  name: "input"
+  type: "data"
+  size: 300
+  active_type: ""
+}
+layers {
+  name: "__clip_0__"
+  type: "clip"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+    clip_conf {
+      min: -10
+      max: 10
+    }
+  }
+}
+input_layer_names: "input"
+output_layer_names: "__clip_0__"
+sub_models {
+  name: "root"
+  layer_names: "input"
+  layer_names: "__clip_0__"
+  input_layer_names: "input"
+  output_layer_names: "__clip_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_l2_norm_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_l2_norm_layer.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..c2786ff55c7023d856d739face5e747cc5fee870
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_l2_norm_layer.protostr
@@ -0,0 +1,27 @@
+type: "nn"
+layers {
+  name: "input"
+  type: "data"
+  size: 300
+  active_type: ""
+}
+layers {
+  name: "__row_l2_norm_layer_0__"
+  type: "row_l2_norm"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+  }
+}
+input_layer_names: "input"
+output_layer_names: "__row_l2_norm_layer_0__"
+sub_models {
+  name: "root"
+  layer_names: "input"
+  layer_names: "__row_l2_norm_layer_0__"
+  input_layer_names: "input"
+  output_layer_names: "__row_l2_norm_layer_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_clip_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_clip_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f066fe1fb30877bf40bb6299d35546f7427989a5
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_clip_layer.py
@@ -0,0 +1,6 @@
+from paddle.trainer_config_helpers import *
+
+data = data_layer(name='input', size=300)
+clip = clip_layer(input=data, min=-10, max=10)
+
+outputs(clip)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_row_l2_norm_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_row_l2_norm_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac8badb26a40e96e75225e6f61aa536cd28e9098
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_row_l2_norm_layer.py
@@ -0,0 +1,6 @@
+from paddle.trainer_config_helpers import *
+
+data = data_layer(name='input', size=300)
+row_l2_norm = row_l2_norm_layer(input=data)
+
+outputs(row_l2_norm)
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 540636a0e8100fbf97231bd548dbc1176b07daca..4619b0edc3dd7e253e01f7fee5e6a8641340d291 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -8,7 +8,6 @@ add_python_test(test_framework
     test_fc_op.py
     test_add_two_op.py
     test_sgd_op.py
-    test_cross_entropy_op.py
     test_mul_op.py
     test_mean_op.py
     test_sigmoid_op.py
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 99085c367221150c8386a24e8d90d58fd63894c4..98fae1b975ad6243b20e5c19ec6ff68d5536cd74 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -26,40 +26,45 @@ class OpTestMeta(type):
 
             scope = core.Scope()
             kwargs = dict()
+            places = [core.CPUPlace()]
+            if core.is_compile_gpu():
+                places.append(core.GPUPlace(0))
 
-            for in_name in func.all_input_args:
-                if hasattr(self, in_name):
-                    kwargs[in_name] = in_name
-                    var = scope.new_var(in_name).get_tensor()
-                    arr = getattr(self, in_name)
-                    var.set_dims(arr.shape)
-                    var.set(arr)
-                else:
-                    kwargs[in_name] = "@EMPTY@"
+            for place in places:
+                for in_name in func.all_input_args:
+                    if hasattr(self, in_name):
+                        kwargs[in_name] = in_name
+                        var = scope.new_var(in_name).get_tensor()
+                        arr = getattr(self, in_name)
+                        var.set_dims(arr.shape)
+                        var.set(arr, place)
+                    else:
+                        kwargs[in_name] = "@EMPTY@"
 
-            for out_name in func.all_output_args:
-                if hasattr(self, out_name):
-                    kwargs[out_name] = out_name
-                    scope.new_var(out_name).get_tensor()
+                for out_name in func.all_output_args:
+                    if hasattr(self, out_name):
+                        kwargs[out_name] = out_name
+                        scope.new_var(out_name).get_tensor()
 
-            for attr_name in func.all_attr_args:
-                if hasattr(self, attr_name):
-                    kwargs[attr_name] = getattr(self, attr_name)
+                for attr_name in func.all_attr_args:
+                    if hasattr(self, attr_name):
+                        kwargs[attr_name] = getattr(self, attr_name)
 
-            op = func(**kwargs)
+                op = func(**kwargs)
 
-            op.infer_shape(scope)
+                op.infer_shape(scope)
 
-            ctx = core.DeviceContext.cpu_context()
-            op.run(scope, ctx)
+                ctx = core.DeviceContext.create(place)
+                op.run(scope, ctx)
 
-            for out_name in func.all_output_args:
-                actual = numpy.array(scope.find_var(out_name).get_tensor())
-                expect = getattr(self, out_name)
-                # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
-                # has some diff, and could not pass unittest. So I set decimal 3 here.
-                # And I will check this in future.
-                numpy.testing.assert_almost_equal(actual, expect, decimal=3)
+                for out_name in func.all_output_args:
+                    actual = numpy.array(scope.find_var(out_name).get_tensor())
+                    expect = getattr(self, out_name)
+                    # TODO(qijun) The default decimal is 7, but numpy.dot and
+                    # eigen.mul differ slightly and fail the unittest at that
+                    # precision, so decimal is set to 3 here for now.
+                    numpy.testing.assert_almost_equal(actual, expect, decimal=3)
 
         obj.test_all = test_all
         return obj
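With the place loop in place, any test built on `OpTestMeta` now runs once on `CPUPlace` and, when PaddlePaddle is compiled with GPU support, again on `GPUPlace(0)`. A hypothetical test in the same style as the ones below; the attribute names `X` and `Y` are assumptions about the sigmoid op's input/output names:

```python
import unittest

import numpy

from op_test_util import OpTestMeta


class TestSigmoidOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "sigmoid"  # op type registered on the C++ side
        self.X = numpy.random.random((32, 100)).astype("float32")
        self.Y = 1.0 / (1.0 + numpy.exp(-self.X))  # expected output


if __name__ == '__main__':
    unittest.main()
```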
diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py
index a06d7a78ecf838a49e5f2808d3686c6b92faa8ce..6e6643201bf361fce1bad7de10b2562f0525e00a 100644
--- a/python/paddle/v2/framework/tests/test_add_two_op.py
+++ b/python/paddle/v2/framework/tests/test_add_two_op.py
@@ -1,6 +1,10 @@
 import unittest
-from op_test_util import OpTestMeta
+
 import numpy
+import paddle.v2.framework.core as core
+import paddle.v2.framework.create_op_creation_methods as creation
+
+from op_test_util import OpTestMeta
 
 
 class TestAddOp(unittest.TestCase):
@@ -8,10 +12,19 @@ class TestAddOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "add_two"
-        self.X = numpy.random.random((342, 345)).astype("float32")
-        self.Y = numpy.random.random((342, 345)).astype("float32")
+        self.X = numpy.random.random((102, 105)).astype("float32")
+        self.Y = numpy.random.random((102, 105)).astype("float32")
         self.Out = self.X + self.Y
 
 
+class TestAddGradOp(unittest.TestCase):
+    def test_add_grad(self):
+        op = creation.op_creations.add_two(X="X", Y="Y", Out="Out")
+        backward_op = core.Operator.backward(op, set())
+        self.assertEqual(backward_op.type(), "add_two_grad")
+        expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
+        self.assertEqual(expected, str(backward_op))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py
index 43931aac406cd93beede008066aa1c0c00eba6ea..00dc4399aaf59e6382692c3a4356f89a7e79a0c5 100644
--- a/python/paddle/v2/framework/tests/test_fc_op.py
+++ b/python/paddle/v2/framework/tests/test_fc_op.py
@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation
 class TestFc(unittest.TestCase):
     def test_fc(self):
         scope = core.Scope()
+        place = core.CPUPlace()
         x = scope.new_var("X")
+
         x_tensor = x.get_tensor()
         x_tensor.set_dims([1000, 784])
-        x_tensor.alloc_float()
+        x_tensor.alloc_float(place)
 
         w = scope.new_var("W")
         w_tensor = w.get_tensor()
         w_tensor.set_dims([784, 100])
-        w_tensor.alloc_float()
+        w_tensor.alloc_float(place)
 
-        w_tensor.set(numpy.random.random((784, 100)).astype("float32"))
+        w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
 
         # Set a real numpy array here.
         # x_tensor.set(numpy.array([]))
@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase):
         op.infer_shape(scope)
         self.assertEqual([1000, 100], tensor.shape())
 
-        ctx = core.DeviceContext.cpu_context()
+        ctx = core.DeviceContext.create(place)
 
         op.run(scope, ctx)
 
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index 0a87e66cd03af1bf84be8ffe111e4a8c3a24d6dc..e1ac66d3a4d23d617f7c5a4d97d070b2660954c8 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "mul"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.Y = np.random.random((784, 100)).astype("float32")
+        self.X = np.random.random((32, 84)).astype("float32")
+        self.Y = np.random.random((84, 100)).astype("float32")
         self.Out = np.dot(self.X, self.Y)
 
 
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index ef1514983c03f822f84b85437d1cfe653b6a1a2e..04abc14ee198fe4e2307e009c696a2b40ec271b6 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "rowwise_add"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.b = np.random.random(784).astype("float32")
+        self.X = np.random.random((32, 84)).astype("float32")
+        self.b = np.random.random(84).astype("float32")
         self.Out = np.add(self.X, self.b)
 
 
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index 405d73b224fa153e50b4ec408a921f2bdaab46aa..ca03cc11abe2ceb31b33a87797aa752943dd2a7d 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase):
 
     def setUp(self):
         self.type = "sgd"
-        self.param = numpy.random.random((342, 345)).astype("float32")
-        self.grad = numpy.random.random((342, 345)).astype("float32")
+        self.param = numpy.random.random((102, 105)).astype("float32")
+        self.grad = numpy.random.random((102, 105)).astype("float32")
         self.learning_rate = 0.1
         self.param_out = self.param - self.learning_rate * self.grad
 
diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/framework/tests/test_tensor.py
index 6d59863cea29832f648139e07a134050e22bfa21..1af39818a305215b45219b8c5f0a10630fd64279 100644
--- a/python/paddle/v2/framework/tests/test_tensor.py
+++ b/python/paddle/v2/framework/tests/test_tensor.py
@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase):
     def test_int_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
+        place = core.CPUPlace()
+
         tensor = var.get_tensor()
 
         tensor.set_dims([1000, 784])
-        tensor.alloc_int()
-
+        tensor.alloc_int(place)
         tensor_array = numpy.array(tensor)
         self.assertEqual((1000, 784), tensor_array.shape)
         tensor_array[3, 9] = 1
         tensor_array[19, 11] = 2
-        tensor.set(tensor_array)
+        tensor.set(tensor_array, place)
 
         tensor_array_2 = numpy.array(tensor)
         self.assertEqual(1.0, tensor_array_2[3, 9])
@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase):
     def test_float_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
+        place = core.CPUPlace()
+
         tensor = var.get_tensor()
 
         tensor.set_dims([1000, 784])
-        tensor.alloc_float()
+        tensor.alloc_float(place)
 
         tensor_array = numpy.array(tensor)
         self.assertEqual((1000, 784), tensor_array.shape)
         tensor_array[3, 9] = 1.0
         tensor_array[19, 11] = 2.0
-        tensor.set(tensor_array)
+        tensor.set(tensor_array, place)
 
         tensor_array_2 = numpy.array(tensor)
         self.assertAlmostEqual(1.0, tensor_array_2[3, 9])