diff --git a/CMakeLists.txt b/CMakeLists.txt
index dcff6b54cafce35846627e78cfcdac65fae7e686..2a6b0a20e441676c85c9ed8f8ad1a6e7abdf1ea8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,7 +13,6 @@
 # limitations under the License
 
 cmake_minimum_required(VERSION 3.0)
-
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
 set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index e42e75c12ab1e5133f5ecbdb90ef26e3f8df5133..534be0abe246ac70950d85ad05441825c8ca768a 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -290,8 +290,22 @@ function(go_library TARGET_NAME)
     set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
   endif()
 
-  # Add dummy code to support `make target_name` under Terminal Command
   set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
+
+  # This custom command always runs because its OUTPUT file is never
+  # actually created.
+  add_custom_command(
+    OUTPUT dummy_rebuild_${TARGET_NAME}
+    COMMAND cmake -E touch ${dummyfile}
+    )
+  # Create a custom target that depends on the custom command output
+  # file, so the custom command can be referenced as a dependency by
+  # `add_dependencies`.
+  add_custom_target(rebuild_${TARGET_NAME}
+    DEPENDS dummy_rebuild_${TARGET_NAME}
+    )
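+  # Net effect: ${dummyfile} is touched on every build, so ${TARGET_NAME}
+  # is always considered out of date and gets rebuilt.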
+
+  # Add a dummy source file so that `make target_name` works from the
+  # command line.
   file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
   if (go_library_SHARED OR go_library_shared)
     add_library(${TARGET_NAME} SHARED ${dummyfile})
@@ -302,6 +316,12 @@ function(go_library TARGET_NAME)
     add_dependencies(${TARGET_NAME} ${go_library_DEPS})
   endif(go_library_DEPS)
 
+  # The "source file" of the library is `${dummyfile}` which never
+  # change, so the target will never rebuild. Make the target depends
+  # on the custom command that touches the library "source file", so
+  # rebuild will always happen.
+  add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME})
+
   set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
 
   file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 74937b2b710412ad3e4c45b5147474220cd9f771..9e9491d983b3e2b5b4f70692bb9171abc3ee895d 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -1,24 +1,25 @@
 # ddim lib
-cc_library(enforce SRCS enforce.cc DEPS glog)
-cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
 cc_library(ddim SRCS ddim.cc DEPS eigen3)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
+
+cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory)
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
+cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
+
 cc_test(variable_test SRCS variable_test.cc)
 cc_test(scope_test SRCS scope_test.cc)
+
 proto_library(attr_type SRCS attr_type.proto)
 proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
-cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
+cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 
-# cc_library(fc_op SRCS fully_connected_op.cc DEPS operator)
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
 
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
@@ -29,4 +30,4 @@ add_dependencies(framework_py_proto framework_py_proto_init)
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
 # cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
 cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h
index c0c33d81149ac2fc2a9a57d90931ef32375fe1d0..ea5614a45f3a77a851358aff80abbc276c9972ba 100644
--- a/paddle/framework/attr_checker.h
+++ b/paddle/framework/attr_checker.h
@@ -4,8 +4,9 @@
 #include <functional>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
@@ -41,6 +42,35 @@ class DefaultValueSetter {
   T default_value_;
 };
 
+template <typename T>
+class EnumInContainer {
+ public:
+  explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
+  void operator()(T& val) const {
+    PADDLE_ENFORCE(container_.find(val) != container_.end(),
+                   "Value %s is not in enum container %s", val,
+                   ContainerDebugString());
+  }
+
+ private:
+  std::string ContainerDebugString() const {
+    std::ostringstream sout;
+    sout << "[";
+    size_t cnt = 0;
+    for (auto& v : container_) {
+      sout << v;
+      ++cnt;
+      if (cnt != container_.size()) {
+        sout << " ,";
+      }
+    }
+    sout << "]";
+    return sout.str();
+  }
+
+  std::unordered_set<T> container_;
+};
+
 // check whether a certain attribute fit its limits
 // an attribute can have more than one limits
 template <typename T>
@@ -50,6 +80,11 @@ class TypedAttrChecker {
  public:
   TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
 
+  TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
+    value_checkers_.push_back(EnumInContainer<T>(range));
+    return *this;
+  }
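+  // InEnum returns *this, so checks can be chained. Illustratively, an int
+  // attribute restricted to three values would use: .InEnum({0, 1, 2})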
+
   TypedAttrChecker& LargerThan(const T& lower_bound) {
     value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
     return *this;
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index d2ef85afe55e640a17b8c957bac61d175e69ff3f..545c1dcc2a1682839d90194002fdbb748d85e808 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/ddim.h"
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index 070850375d1bd3a61b98184495c979573bf9542c..9fcc657edcd5459d0a42a64d708603a4bcd53cf0 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -19,7 +19,7 @@ limitations under the License. */
 #include <stdexcept>
 #include <vector>
 #include "paddle/framework/dim.h"
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
@@ -119,17 +119,6 @@ int arity(const DDim& ddim);
 
 std::ostream& operator<<(std::ostream&, const DDim&);
 
-template <int NDIMS>
-Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
-  int rank = arity(dims);
-  PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
-  Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
-  for (int d = 0; d < rank; d++) {
-    dsizes[d] = dims[d];
-  }
-  return dsizes;
-}
-
 }  // namespace framework
 }  // namespace paddle
 
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
new file mode 100644
index 0000000000000000000000000000000000000000..4ba4fd4d110330805faf2468bd406cb23c6f1b1c
--- /dev/null
+++ b/paddle/framework/eigen.h
@@ -0,0 +1,84 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/tensor.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+
+namespace paddle {
+namespace framework {
+
+// EigenDim converts paddle::framework::DDim into Eigen::DSizes.
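+// e.g. EigenDim<3>::From(make_ddim({1, 2, 3})) yields DSizes (1, 2, 3); the
+// rank of the DDim must match D, or PADDLE_ENFORCE fails.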
+template <int D>
+struct EigenDim {
+  using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
+
+  static Type From(const DDim& dims) {
+    PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
+    Type ret;
+    for (int d = 0; d < arity(dims); d++) {
+      ret[d] = dims[d];
+    }
+    return ret;
+  }
+};
+
+// Interpret paddle::framework::Tensor as EigenTensor and EigenConstTensor.
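+// e.g. (cf. eigen_test.cc) EigenTensor<float, 3>::From(t) maps a Tensor t
+// with dims {1, 2, 3} as a 1x2x3 Eigen::TensorMap without copying data.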
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+struct EigenTensor {
+  // TODO(qijun): the default type is unaligned for now; we will benchmark
+  // the aligned and unaligned versions in the future.
+  using Type = Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, IndexType>>;
+
+  using ConstType =
+      Eigen::TensorMap<Eigen::Tensor<const T, D, MajorType, IndexType>>;
+
+  static Type From(Tensor& tensor, DDim dims) {
+    return Type(tensor.data<T>(), EigenDim<D>::From(dims));
+  }
+
+  static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); }
+
+  static ConstType From(const Tensor& tensor, DDim dims) {
+    return ConstType(tensor.data<T>(), EigenDim<D>::From(dims));
+  }
+
+  static ConstType From(const Tensor& tensor) {
+    return From(tensor, tensor.dims_);
+  }
+};
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
+  // Flatten reshapes a Tensor into a one-dimensional EigenVector.
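+  // e.g. a Tensor with dims {1, 2, 3} is viewed as a vector of length 6.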
+  static typename EigenTensor<T, 1>::Type Flatten(Tensor& tensor) {
+    return EigenTensor<T, 1>::From(
+        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+  }
+
+  static typename EigenTensor<T, 1>::ConstType Flatten(const Tensor& tensor) {
+    return EigenTensor<T, 1>::From(
+        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+  }
+};
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = EigenTensor<T, 2, MajorType, IndexType>;
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a9fa728e49a0dcc781e520a22c1ee5f921c4c733
--- /dev/null
+++ b/paddle/framework/eigen_test.cc
@@ -0,0 +1,101 @@
+/*
+  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+  http://www.apache.org/licenses/LICENSE-2.0
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+*/
+
+#include "paddle/framework/eigen.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace framework {
+
+TEST(EigenDim, From) {
+  EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3}));
+  ASSERT_EQ(1, ed[0]);
+  ASSERT_EQ(2, ed[1]);
+  ASSERT_EQ(3, ed[2]);
+}
+
+TEST(Eigen, Tensor) {
+  Tensor t;
+  float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
+  for (int i = 0; i < 1 * 2 * 3; i++) {
+    p[i] = static_cast<float>(i);
+  }
+
+  EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
+
+  ASSERT_EQ(1, et.dimension(0));
+  ASSERT_EQ(2, et.dimension(1));
+  ASSERT_EQ(3, et.dimension(2));
+
+  for (int i = 0; i < 1; i++) {
+    for (int j = 0; j < 2; j++) {
+      for (int k = 0; k < 3; k++) {
+        ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f);
+      }
+    }
+  }
+}
+
+TEST(Eigen, VectorFrom) {
+  Tensor t;
+  float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
+  for (int i = 0; i < 6; i++) {
+    p[i] = static_cast<float>(i);
+  }
+
+  EigenVector<float>::Type ev = EigenVector<float>::From(t);
+
+  ASSERT_EQ(6, ev.dimension(0));
+
+  for (int i = 0; i < 6; i++) {
+    ASSERT_NEAR(i, ev(i), 1e-6f);
+  }
+}
+
+TEST(Eigen, VectorFlatten) {
+  Tensor t;
+  float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
+  for (int i = 0; i < 1 * 2 * 3; i++) {
+    p[i] = static_cast<float>(i);
+  }
+
+  EigenVector<float>::Type ev = EigenVector<float>::Flatten(t);
+
+  ASSERT_EQ(1 * 2 * 3, ev.dimension(0));
+
+  for (int i = 0; i < 1 * 2 * 3; i++) {
+    ASSERT_NEAR(i, ev(i), 1e-6f);
+  }
+}
+
+TEST(Eigen, Matrix) {
+  Tensor t;
+  float* p = t.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
+  for (int i = 0; i < 2 * 3; i++) {
+    p[i] = static_cast<float>(i);
+  }
+
+  EigenMatrix<float>::Type em = EigenMatrix<float>::From(t);
+
+  ASSERT_EQ(2, em.dimension(0));
+  ASSERT_EQ(3, em.dimension(1));
+
+  for (int i = 0; i < 2; i++) {
+    for (int j = 0; j < 3; j++) {
+      ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f);
+    }
+  }
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/enforce.cc b/paddle/framework/enforce.cc
deleted file mode 100644
index 644930ff989bb8935f37642c117084f580379bd7..0000000000000000000000000000000000000000
--- a/paddle/framework/enforce.cc
+++ /dev/null
@@ -1,15 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/framework/enforce.h"
diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h
deleted file mode 100644
index ffce8148e9516a5720757c87685ff6bd2937977c..0000000000000000000000000000000000000000
--- a/paddle/framework/enforce.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <glog/logging.h>
-#include <paddle/string/printf.h>
-#include <exception>
-#include <sstream>
-
-namespace paddle {
-namespace framework {
-
-/**
- * @brief Enforce exception. Inherits std::exception
- *
- * All enforce condition not met, will throw an EnforceNotMet exception.
- */
-class EnforceNotMet : public std::exception {
- public:
-  EnforceNotMet(const std::string& msg, const char* file, int fileline) {
-    std::ostringstream sout;
-    sout << msg << " at [" << file << ":" << fileline << "];";
-    all_msg_ = sout.str();
-  }
-
-  const char* what() const noexcept override { return all_msg_.c_str(); }
-
- private:
-  std::string all_msg_;
-};
-
-// From https://stackoverflow.com/questions/30130930/
-// __buildin_expect is in C++ 11 standard. Since the condition which enforced
-// should be true in most situation, it will make the compiler generate faster
-// code by adding `UNLIKELY` macro.
-#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
-
-/**
- * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ &
- * __LINE__
- *
- * This macro take __VA_ARGS__, user can pass any type if that type can
- * serialize to std::ostream
- */
-#define PADDLE_THROW(...)                                            \
-  do {                                                               \
-    throw ::paddle::framework::EnforceNotMet(                        \
-        ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
-  } while (0)
-
-/**
- * @brief Enforce a condition, otherwise throw an EnforceNotMet
- */
-#ifdef NDEBUG
-#define PADDLE_ENFORCE(condition, ...) \
-  do {                                 \
-    if (UNLIKELY(!(condition))) {      \
-      PADDLE_THROW(__VA_ARGS__);       \
-    }                                  \
-  } while (0)
-#else
-#define PADDLE_ENFORCE(condition, ...) \
-  CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__);
-#endif
-
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc
index c39c87fcd6dc7d8b78c8112b0f258774e2bf74d7..8902e2bcf1245f3cd711bfd4e76b7d4ab9ed4d31 100644
--- a/paddle/framework/net.cc
+++ b/paddle/framework/net.cc
@@ -21,10 +21,7 @@ namespace paddle {
 namespace framework {
 
 std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
-  // NetPtr->reset(new PlainNet);
-  // NetPtr grad_ops = new PlainNet;
-  std::shared_ptr<PlainNet> grad_ops;
-  grad_ops.reset(new PlainNet);
+  auto grad_ops = std::make_shared<PlainNet>();
   for (auto& op : ForwardOps->ops_) {
     auto op_grad = OpRegistry::CreateGradOp(op);
     grad_ops->AddOp(op_grad);
@@ -33,7 +30,9 @@ std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
   return grad_ops;
 }
 
-void PlainNet::CompleteAddOp() {
+void PlainNet::CompleteAddOp(bool calc) {
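+  // Mark the net complete first so that later AddOp calls are rejected;
+  // when `calc` is false, skip the temporary-output bookkeeping below.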
+  add_op_done_ = true;
+  if (!calc) return;
   std::unordered_set<std::string> input_set;
   std::unordered_set<std::string> output_set;
   std::unordered_set<std::string> temp_output;
@@ -66,7 +65,6 @@ void PlainNet::CompleteAddOp() {
   }
 
   attrs_["temporary_index"] = tmp_index;
-  add_op_done_ = true;
 }
 
 std::string PlainNet::DebugString() const {
diff --git a/paddle/framework/net.h b/paddle/framework/net.h
index 1103c8ef2b01697aa3a92402a3325a1a8e6c700b..60bfd3ef5e8a1cc48d7e583d5863a25609ce82b6 100644
--- a/paddle/framework/net.h
+++ b/paddle/framework/net.h
@@ -16,7 +16,6 @@ limitations under the License. */
 
 #include <paddle/framework/op_desc.pb.h>
 #include <paddle/framework/operator.h>
-#include "paddle/framework/net_proto.pb.h"
 #include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/scope.h"
@@ -41,7 +40,7 @@ namespace framework {
 class Net : public OperatorBase {
  public:
   virtual void AddOp(const OperatorPtr& op) = 0;
-  virtual void CompleteAddOp() = 0;
+  virtual void CompleteAddOp(bool calc) = 0;
 };
 
 using NetPtr = std::shared_ptr<Net>;
@@ -86,7 +85,7 @@ class PlainNet : public Net {
     ops_.push_back(op);
   }
 
-  void CompleteAddOp() override;
+  void CompleteAddOp(bool calculate = true) override;
 
   std::string DebugString() const override;
 
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
index 18151c56d9acb3b10d5949f92b3e093d38c796e0..e62a9914dcffb1a12a2fced0d1dc8ba14aa5dbd6 100644
--- a/paddle/framework/net_op_test.cc
+++ b/paddle/framework/net_op_test.cc
@@ -2,7 +2,11 @@
 #include <paddle/framework/net.h>
 #include <paddle/framework/op_registry.h>
 #include <paddle/framework/operator.h>
-#include "paddle/framework/fully_connected_op.h"
+
+USE_OP(add_two);
+USE_OP(mul);
+USE_OP(sigmoid);
+USE_OP(softmax);
 
 namespace paddle {
 namespace framework {
@@ -62,22 +66,30 @@ TEST(OpKernel, all) {
   net->Run(scope, dev_ctx);
   ASSERT_EQ(2, infer_shape_cnt);
   ASSERT_EQ(2, run_cnt);
-
-  ASSERT_THROW(net->AddOp(op2), EnforceNotMet);
+  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
 }
-
 TEST(AddBackwardOp, TestGradOp) {
   auto net = std::make_shared<PlainNet>();
   ASSERT_NE(net, nullptr);
-  auto op1 = std::make_shared<FCOp>();
-  op1->inputs_ = {"x", "w1", "b1"};
-  op1->outputs_ = {"y"};
-  net->AddOp(op1);
+  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(
+      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
   auto grad_ops = AddBackwardOp(net);
   for (auto& op : grad_ops->ops_) {
     op->DebugString();
   }
 }
 
+// TODO(zhihong): add fc grad without registering.
+// TEST(AddBackwardOp, TestNoGradOp) {
+//   auto net = std::make_shared<PlainNet>();
+//   ASSERT_NE(net, nullptr);
+//   net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
+//   {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
+//     op->DebugString();
+//   }
+// }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 4a197102d6e9937c341b8bfdf1afcc863d7ff6d8..0aa1eca837b9dcb51be3c540b0d33fcad08956cb 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -286,7 +286,13 @@ class OpRegistry {
   }
 
   static OperatorPtr CreateGradOp(OperatorPtr op) {
-    OperatorPtr grad_op(grad_creators().at(op->type_)());
+    auto it = grad_creators().find(op->type_);
+    if (it == grad_creators().end()) {
+      LOG(INFO) << op->type_ << " does not have a gradient op";
+      return nullptr;
+    }
+    OperatorPtr grad_op(it->second());
     grad_op->type_ = op->type_;
 
     AssembleGradInOut(op, grad_op);
@@ -470,11 +476,11 @@ class GradOpRegisterHelper {
  */
 #define REGISTER_GRADIENT_OP(__op_type, __op_class)            \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                              \
-      __reg_gradient_op_##__reg_op__##__op_type,               \
+      __reg_gradient_op__##__op_type,                          \
       "REGISTER_GRADIENT_OP must be in global namespace");     \
   static ::paddle::framework::GradOpRegisterHelper<__op_class> \
-      __op_register_##__op_type##__(#__op_type);               \
-  int __op_register_##__op_type##_handle__() { return 0; }
+      __op_gradient_register_##__op_type##__(#__op_type);      \
+  int __op_gradient_register_##__op_type##_handle__() { return 0; }
 
 /**
  * Macro to Register OperatorKernel.
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index d3a51a361aa56b26b87d79057f6700bd87264ca4..32a7e88a894fb61a460443b7d593a6cf44bc98c5 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) {
   try {
     paddle::framework::OperatorPtr op __attribute__((unused)) =
         paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg = "larger_than check fail";
     const char* err_msg = err.what();
@@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) {
   try {
     paddle::framework::OperatorPtr op __attribute__((unused)) =
         paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg = "Attribute 'test_attr' is required!";
     const char* err_msg = err.what();
@@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) {
   try {
     paddle::framework::OperatorPtr op __attribute__((unused)) =
         paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg = "'test_attr' must be even!";
     const char* err_msg = err.what();
@@ -196,7 +196,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
+  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
 }
 
 class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
@@ -212,5 +212,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
+  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
 }
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 4f07350e59dea72431417876f41f172e51ea53f9..93c6fad5d3d9f3de100d30161e6e438eb43816a2 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -19,9 +19,8 @@ limitations under the License. */
 #include <memory>
 #include <typeindex>
 #include "paddle/framework/ddim.h"
-#include "paddle/framework/enforce.h"
-#include "paddle/framework/tensor_types.h"
 #include "paddle/memory/memory.h"
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
@@ -35,6 +34,15 @@ struct CastToPyBufferImpl;
 namespace framework {
 
 class Tensor {
+  template <bool less, size_t i, typename... args>
+  friend struct paddle::pybind::details::CastToPyBufferImpl;
+
+  template <typename T, size_t D, int MajorType, typename IndexType>
+  friend struct EigenTensor;
+
+  template <typename T, int MajorType, typename IndexType>
+  friend struct EigenVector;
+
  public:
   Tensor() : offset_(0) {}
 
@@ -46,7 +54,7 @@ class Tensor {
   }
 
   template <typename T>
-  T* raw_data() const {
+  T* data() {
     CheckDims<T>();
     return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                                 offset_);
@@ -71,14 +79,14 @@ class Tensor {
         holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
             boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
       } else if (platform::is_gpu_place(place)) {
-#ifdef __CUDACC__
+#ifdef PADDLE_ONLY_CPU
+        PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+#else
         holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
             boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
-#else
-        PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device.");
 #endif
       } else {
-        PADDLE_ENFORCE(true, "Unknown 'place'.");
+        PADDLE_THROW("Unknown 'place'.");
       }
       offset_ = 0;
     }
@@ -86,66 +94,6 @@ class Tensor {
                                 offset_);
   }
 
-  template <typename T, size_t NDIMS>
-  typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
-    Eigen::array<Eigen::DenseIndex, NDIMS> dims =
-        paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
-    return typename TTypes<T, NDIMS>::Tensor(raw_data<T>(), dims);
-  }
-
-  template <typename T, size_t NDIMS>
-  typename TTypes<T, NDIMS>::Tensor tensor() {
-    return typename TTypes<T, NDIMS>::Tensor(
-        raw_data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
-  }
-
-  // flat to rank = 1
-  template <typename T>
-  typename TTypes<T>::Flat flat() {
-    return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
-  }
-
-  // to TensorType Vec
-  template <typename T>
-  typename TTypes<T>::Vec vec() {
-    return tensor<T, 1>();
-  }
-
-  // to TensorType Matrix
-  template <typename T>
-  typename TTypes<T>::Matrix matrix() {
-    return tensor<T, 2>();
-  }
-
-  // const versions of all the methods above.
-  template <typename T, size_t NDIMS>
-  typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) const {
-    Eigen::array<Eigen::DenseIndex, NDIMS> dims =
-        paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
-    return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
-  }
-
-  template <typename T, size_t NDIMS>
-  typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
-    return typename TTypes<T, NDIMS>::Tensor(
-        data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
-  }
-
-  template <typename T>
-  typename TTypes<T>::ConstFlat flat() const {
-    return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
-  }
-
-  template <typename T>
-  typename TTypes<T>::ConstVec vec() const {
-    return tensor<T, 1>();
-  }
-
-  template <typename T>
-  typename TTypes<T>::ConstMatrix matrix() const {
-    return tensor<T, 2>();
-  }
-
   template <typename T>
   void ShareDataFrom(const Tensor& src) {
     src.CheckDims<T>();
@@ -251,8 +199,6 @@ class Tensor {
   std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
   DDim dims_;
   size_t offset_;  // marks the begin of tensor data area.
-  template <bool less, size_t i, typename... args>
-  friend struct paddle::pybind::details::CastToPyBufferImpl;
 };
 
 }  // namespace framework
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 84c6f0cf6558819440458688ca52b06c1cf11dd0..8a7cbbd0de6fd6aaafa8649abb8628e971bc49c1 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
   bool caught = false;
   try {
     src_tensor.data<double>();
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (std::runtime_error& err) {
     caught = true;
     std::string msg =
         "Tenosr holds no memory. Call Tensor::mutable_data first.";
@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataFrom) {
     bool caught = false;
     try {
       dst_tensor.ShareDataFrom<float>(src_tensor);
-    } catch (EnforceNotMet err) {
+    } catch (std::runtime_error& err) {
       caught = true;
       std::string msg =
           "Tenosr holds no memory. Call Tensor::mutable_data first.";
diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h
deleted file mode 100644
index 4bf27a377e828a56f9679e6698d314457d7caf0b..0000000000000000000000000000000000000000
--- a/paddle/framework/tensor_types.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "unsupported/Eigen/CXX11/Tensor"
-
-namespace paddle {
-namespace framework {
-
-// Helper to define Tensor types given that the scalar is of type T.
-template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
-struct TTypes {
-  // Rank-<NDIMS> tensor of scalar type T.
-  typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
-                           Eigen::Aligned>
-      Tensor;
-  typedef Eigen::TensorMap<
-      Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
-      ConstTensor;
-
-  // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
-  typedef Eigen::TensorMap<
-      Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
-      Eigen::Aligned>
-      Scalar;
-  typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
-                                                  Eigen::RowMajor, IndexType>,
-                           Eigen::Aligned>
-      ConstScalar;
-
-  // Rank-1 tensor (vector) of scalar type T.
-  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
-                           Eigen::Aligned>
-      Flat;
-  typedef Eigen::TensorMap<
-      Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
-      ConstFlat;
-  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
-                           Eigen::Aligned>
-      Vec;
-  typedef Eigen::TensorMap<
-      Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
-      ConstVec;
-
-  // Rank-2 tensor (matrix) of scalar type T.
-  typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
-                           Eigen::Aligned>
-      Matrix;
-  typedef Eigen::TensorMap<
-      Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
-      ConstMatrix;
-};
-
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index a5b14c0c71c18da1bb0b506c663f8680b1c3830a..2bec00cdb2d32d01a5a24e662bcca07f4154939c 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -36,6 +36,7 @@ if(WITH_GPU)
     add_simple_unittest(MulOpTest)
     add_simple_unittest(CosSimOpTest)
     add_simple_unittest(RowConvOpTest)
+    add_simple_unittest(CropOpTest)
 endif()
 
 add_simple_unittest(ConvOpTest)
diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f12ee43e3d72f9ac776eaff93914228850694dd2
--- /dev/null
+++ b/paddle/function/CropOp.cpp
@@ -0,0 +1,177 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "CropOp.h"
+#include "paddle/function/TensorShape.h"
+#include "paddle/math/Vector.h"
+
+namespace paddle {
+
+template <>
+void Crop<DEVICE_TYPE_CPU>(real* outputs,
+                           const real* inputs,
+                           const TensorShape inShape,
+                           const TensorShape outShape,
+                           const FuncConfig& conf) {
+  std::vector<uint32_t> crop_corner =
+      conf.get<std::vector<uint32_t>>("crop_corner");
+  int cCrop = crop_corner[1];
+  int hCrop = crop_corner[2];
+  int wCrop = crop_corner[3];
+
+  int num = inShape[0];
+  int inC = inShape[1];
+  int inH = inShape[2];
+  int inW = inShape[3];
+
+  int outC = outShape[1];
+  int outH = outShape[2];
+  int outW = outShape[3];
+
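+  // Row-wise copy: each output row of length outW is a contiguous slice of
+  // the input starting at the crop corner (cCrop, hCrop, wCrop).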
+  for (int n = 0; n < num; n++) {
+    for (int c = 0; c < outC; c++) {
+      for (int h = 0; h < outH; h++) {
+        int outoff = ((n * outC + c) * outH + h) * outW;
+        int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
+        memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real));
+      }
+    }
+  }
+}
+
+template <>
+void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
+                               real* outGrad,
+                               const TensorShape inShape,
+                               const TensorShape outShape,
+                               const FuncConfig& conf) {
+  std::vector<uint32_t> crop_corner =
+      conf.get<std::vector<uint32_t>>("crop_corner");
+  int cCrop = crop_corner[1];
+  int hCrop = crop_corner[2];
+  int wCrop = crop_corner[3];
+
+  int num = outShape[0];
+  int outC = outShape[1];
+  int outH = outShape[2];
+  int outW = outShape[3];
+
+  int inC = inShape[1];
+  int inH = inShape[2];
+  int inW = inShape[3];
+
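+  // Accumulate each row of the incoming gradient into the corresponding
+  // cropped region of the (larger) output gradient.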
+  for (int n = 0; n < num; n++) {
+    for (int c = 0; c < inC; c++) {
+      for (int h = 0; h < inH; h++) {
+        int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
+        int inoff = ((n * inC + c) * inH + h) * inW;
+        CpuVector inG = CpuVector(inW, const_cast<real*>(inGrad + inoff));
+        CpuVector outG = CpuVector(inW, outGrad + outoff);
+        outG += inG;
+      }
+    }
+  }
+}
+
+/**
+ * \brief Crop the input according to the specified corner and shape.
+ *        The input and output are 4D tensors. CropFunc only crops the
+ *        2nd to 4th dimensions.
+ *
+ * Argument in this Function:
+ * \param conf_   A FuncConfig object that contains the cropping corner and
+ *                shape.
+ * \param inputs  A 4D tensor, only one input.
+ * \param outputs A 4D tensor, the output value after cropping.
+ *
+ * For example,
+ * Input(2,2,2,3) = [
+ *                    [ [[1,2,3], [3,4,5]],
+ *                      [[2,3,5], [1,6,7]] ],
+ *                    [ [[4,3,1], [1,8,7]],
+ *                      [[3,8,9], [2,3,5]] ]
+ *                  ] # the input shape is (2,2,2,3)
+ *
+ * conf_: if crop_corner = (0,1,1) and crop_shape = (2,1,2)
+ * Output(2,2,1,2) = [
+ *                    [ [[4,5]],
+ *                      [[6,7]] ],
+ *                    [ [[8,7]],
+ *                      [[3,5]] ]
+ *                  ] # the output shape is (2,2,1,2)
+ */
+template <DeviceType Device>
+class CropFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override { conf_ = config; }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+
+    TensorShape inShape = inputs[0].shape();
+    TensorShape outShape = outputs[0].shape();
+
+    Crop<Device>(outputs[0].data<real>(),
+                 inputs[0].data<real>(),
+                 inShape,
+                 outShape,
+                 conf_);
+  }
+
+private:
+  FuncConfig conf_;
+};
+
+/**
+ * \brief The backward propagation of cropping Function.
+ *
+ * Argument in this Function:
+ * \param conf_    The same meaning as it has in CropFunc.
+ * \param inputs  The gradient with respect to the output value of CropFunc.
+ * \param outputs The gradient with respect to the input value of CropFunc.
+ */
+
+template <DeviceType Device>
+class CropGradFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override { conf_ = config; }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
+    TensorShape outShape = outputs[0].shape();
+    TensorShape inShape = inputs[0].shape();
+
+    CropGrad<Device>(inputs[0].data<real>(),
+                     outputs[0].data<real>(),
+                     inShape,
+                     outShape,
+                     conf_);
+  }
+
+private:
+  FuncConfig conf_;
+};
+
+REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
+REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
+REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
+#endif
+
+}  // namespace paddle
diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..87986fbdc7e33aeb24d947e82a5d67ba23f532de
--- /dev/null
+++ b/paddle/function/CropOp.h
@@ -0,0 +1,51 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief  This function crops the input according to the specified start
+ *         point and shape.
+ *
+ * \param[out] outputs  save results.
+ * \param[in]  inputs   input data.
+ * \param[in]  inShape  the shape of the input tensor.
+ * \param[in]  outShape the shape of the output tensor.
+ * \param[in]  conf     the cropping config.
+ */
+template <DeviceType Device>
+void Crop(real* outputs,
+          const real* inputs,
+          const TensorShape inShape,
+          const TensorShape outShape,
+          const FuncConfig& conf);
+
+/**
+ * \brief   Cropping operation backward.
+ *
+ * \param[in]  inGrad   gradient with respect to the output of Crop.
+ * \param[out] outGrad  gradient with respect to the input of Crop (added to).
+ * \param[in]  inShape  the shape of the input gradient tensor.
+ * \param[in]  outShape the shape of the output gradient tensor.
+ * \param[in]  conf     the cropping config.
+ */
+template <DeviceType Device>
+void CropGrad(const real* inGrad,
+              real* outGrad,
+              const TensorShape inShape,
+              const TensorShape outShape,
+              const FuncConfig& conf);
+}  // namespace paddle
diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..37ce6de0647e5e06a231710b5a53089533de2407
--- /dev/null
+++ b/paddle/function/CropOpGpu.cu
@@ -0,0 +1,113 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hl_base.h"
+#include "CropOp.h"
+
+namespace paddle {
+
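+// One thread per output element: the linear thread index is decomposed into
+// (n, c, h, w) and the element is read from the input at the crop offset.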
+__global__ void KeCrop(real* outputs, const real* inputs,
+                      int inC, int inH, int inW,
+                      int cropC, int cropH, int cropW,
+                      int outC, int outH, int outW, int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % outW;
+    const int h = (idx / outW) % outH;
+    const int c = (idx / outW / outH) % outC;
+    const int n = idx / outW / outH / outC;
+
+    const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w;
+    outputs[idx] = inputs[off];
+  }
+}
+
+template <>
+void Crop<DEVICE_TYPE_GPU>(real* outputs,
+                           const real* inputs,
+                           const TensorShape inShape,
+                           const TensorShape outShape,
+                           const FuncConfig& conf) {
+  std::vector<uint32_t> crop_corner =
+      conf.get<std::vector<uint32_t>>("crop_corner");
+  int cropC = crop_corner[1];
+  int cropH = crop_corner[2];
+  int cropW = crop_corner[3];
+
+  int num = inShape[0];
+  int inC = inShape[1];
+  int inH = inShape[2];
+  int inW = inShape[3];
+
+  int outC = outShape[1];
+  int outH = outShape[2];
+  int outW = outShape[3];
+
+  size_t nth = num * outC * outH * outW;
+  int blockSize = 1024;
+  int gridSize = (nth + blockSize - 1) / blockSize;
+
+  KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (outputs, inputs, inC, inH, inW, cropC, cropH, cropW,
+     outC, outH, outW, nth);
+  CHECK_SYNC("Crop");
+}
+
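+// Backward kernel: one thread per element of the incoming gradient, each
+// accumulated into the matching cropped location of the output gradient.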
+__global__ void KeCropDiff(const real* inGrad, real* outGrad,
+                          int inC, int inH, int inW,
+                          int cropC, int cropH, int cropW,
+                          int outC, int outH, int outW, int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % inW;
+    const int h = (idx / inW) % inH;
+    const int c = (idx / inW / inH) % inC;
+    const int n = idx / inW / inH / inC;
+
+    const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w;
+
+    outGrad[off] += inGrad[idx];
+  }
+}
+
+template <>
+void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad,
+                               real* outGrad,
+                               const TensorShape inShape,
+                               const TensorShape outShape,
+                               const FuncConfig& conf) {
+  std::vector<uint32_t> crop_corner =
+      conf.get<std::vector<uint32_t>>("crop_corner");
+  int cropC = crop_corner[1];
+  int cropH = crop_corner[2];
+  int cropW = crop_corner[3];
+
+  int num = outShape[0];
+  int outC = outShape[1];
+  int outH = outShape[2];
+  int outW = outShape[3];
+
+  int inC = inShape[1];
+  int inH = inShape[2];
+  int inW = inShape[3];
+
+  size_t nth = num * inC * inH * inW;
+  int blockSize = 1024;
+  int gridSize = (nth + blockSize - 1) / blockSize;
+
+  KeCropDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW,
+     outC, outH, outW, nth);
+  CHECK_SYNC("CropGrad");
+}
+
+}  // namespace paddle
diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6f11abfdf6f752857e0a75c62fb2b5c089c206d9
--- /dev/null
+++ b/paddle/function/CropOpTest.cpp
@@ -0,0 +1,49 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+
+namespace paddle {
+
+TEST(Crop, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {5, 5, 32}) {
+      for (size_t imgSizeH : {5, 33, 100}) {
+        for (size_t imgSizeW : {5, 32, 96}) {
+          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
+          for (bool test_grad : {false, true}) {
+            CpuGpuFuncCompare compare(
+                test_grad ? "CropGrad" : "Crop",
+                FuncConfig()
+                    .set<std::vector<uint32_t>>("crop_corner", {0, 1, 1, 1})
+                    .set<std::vector<uint32_t>>("crop_shape", {0, 2, 3, 3}));
+            TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
+            TensorShape outDims{numSamples, 2, 3, 3};
+            compare.addInputs(
+                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
+            compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT,
+                                         test_grad ? inDims : outDims,
+                                         test_grad ? ADD_TO : ASSIGN_TO),
+                               test_grad ? ADD_TO : ASSIGN_TO);
+            compare.run();
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..69ad913420bdb6e1b2ed0618b7f9b78d7477be99
--- /dev/null
+++ b/paddle/gserver/layers/CropLayer.cpp
@@ -0,0 +1,146 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "CropLayer.h"
+#include "paddle/utils/Stat.h"
+namespace paddle {
+
+REGISTER_LAYER(crop, CropLayer);
+
+bool CropLayer::init(const LayerMap& layerMap,
+                     const ParameterMap& parameterMap) {
+  /* Initialize the basic parent class */
+  Layer::init(layerMap, parameterMap);
+  CHECK_LE(static_cast<int>(inputLayers_.size()), 2);
+  CHECK_GE(static_cast<int>(inputLayers_.size()), 1);
+  crop_axis_ = config_.axis();
+  for (int i = 0; i < config_.offset_size(); i++) {
+    crop_offsets_.push_back(config_.offset(i));
+  }
+
+  // 1. get input_0 shape
+  auto& input0_img_conf = config_.inputs(0).image_conf();
+  inDims_ = TensorShape({0,
+                         input0_img_conf.channels(),
+                         input0_img_conf.has_img_size_y()
+                             ? input0_img_conf.img_size_y()
+                             : input0_img_conf.img_size(),
+                         input0_img_conf.img_size()});
+  // 2. get target dims from config
+  if (config_.inputs_size() == 1) {
+    targetDims_ = TensorShape({config_.shape(0),
+                               config_.shape(1),
+                               config_.shape(2),
+                               config_.shape(3)});
+  } else {
+    // 2. get input_1 shape
+    auto& input1_img_conf = config_.inputs(1).image_conf();
+    targetDims_ = TensorShape({0,
+                               input1_img_conf.channels(),
+                               input1_img_conf.has_img_size_y()
+                                   ? input1_img_conf.img_size_y()
+                                   : input1_img_conf.img_size(),
+                               input1_img_conf.img_size()});
+  }
+
+  // 3. get final crop corner
+  int dimSize = 4;
+  crop_corner_ = {0, 0, 0, 0};
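+  // Dimensions before crop_axis_ keep offset 0; from crop_axis_ onward the
+  // configured offsets apply (a single offset is broadcast to all cropped
+  // dimensions).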
+  for (int i = 0; i < dimSize; i++) {
+    if (i >= crop_axis_) {
+      if (crop_offsets_.size() > 1) {
+        crop_corner_[i] = crop_offsets_[i - crop_axis_];
+      } else {
+        crop_corner_[i] = crop_offsets_[0];
+      }
+    }
+  }
+
+  outDims_ = TensorShape(4);
+
+  createFunction(
+      forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_));
+  createFunction(
+      backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_));
+
+  return true;
+}
+
+void CropLayer::setOutDims() {
+  MatrixPtr input = inputLayers_[1]->getOutputValue();
+  size_t batchSize = input->getHeight();
+  // get target dims from input_1
+  if (config_.inputs_size() == 2) {
+    targetDims_.setDim(0, batchSize);
+    int ch = config_.inputs(0).image_conf().channels();
+    if (ch != 0) targetDims_.setDim(1, ch);
+    int h = inputLayers_[1]->getOutput().getFrameHeight();
+    if (h != 0) targetDims_.setDim(2, h);
+    int w = inputLayers_[1]->getOutput().getFrameWidth();
+    if (w != 0) targetDims_.setDim(3, w);
+  }
+  // get final crop shape from target dims and crop axis
+  std::vector<uint32_t> crop_shape;
+  int dimSize = 4;
+  for (int i = 0; i < dimSize; i++) {
+    if (i >= crop_axis_) {
+      crop_shape.push_back(targetDims_[i]);
+    } else {
+      crop_shape.push_back(inDims_[i]);
+    }
+  }
+
+  outDims_.reshape(
+      {crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]});
+  output_.setFrameHeight(crop_shape[2]);
+  output_.setFrameWidth(crop_shape[3]);
+}
+
+void CropLayer::setInDims() {
+  MatrixPtr input = inputLayers_[0]->getOutputValue();
+  size_t batchSize = input->getHeight();
+  inDims_.setDim(0, batchSize);
+  int h = inputLayers_[0]->getOutput().getFrameHeight();
+  if (h != 0) inDims_.setDim(2, h);
+  int w = inputLayers_[0]->getOutput().getFrameWidth();
+  if (w != 0) inDims_.setDim(3, w);
+}
+
+void CropLayer::forward(PassType passType) {
+  Layer::forward(passType);
+  setInDims();
+  setOutDims();
+  int size = outDims_[1] * outDims_[2] * outDims_[3];
+  resetOutput(outDims_[0], size);
+  MatrixPtr outV = getOutputValue();
+  REGISTER_TIMER_INFO("CropForward", getName().c_str());
+
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getInputValue(0), inDims_);
+  outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO);
+  forward_[0]->calc(inputs, outputs);
+}
+
+void CropLayer::backward(const UpdateCallback& callback) {
+  (void)callback;
+  REGISTER_TIMER_INFO("CropBackward", getName().c_str());
+
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getOutputGrad(), outDims_);
+  outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
+  backward_[0]->calc(inputs, outputs);
+}
+}  // namespace paddle
diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..6b6202621023575c1c83049ecbd019656c726e3f
--- /dev/null
+++ b/paddle/gserver/layers/CropLayer.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * \brief  This layer crops the input according to the specified config.
+ *         input_0: input to be cropped
+ *         input_1: optional reference input
+ *         axis:    start dimension to be cropped
+ *         offset:  offset of cropping in each dimension
+ *         shape:   if the reference input layer is not set,
+ *                  crop the input to this shape config
+ */
+class CropLayer : public Layer {
+public:
+  explicit CropLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~CropLayer() {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+
+protected:
+  void setOutDims();
+  void setInDims();
+
+  int32_t crop_axis_;
+  std::vector<uint32_t> crop_offsets_;
+  std::vector<uint32_t> crop_corner_;
+  TensorShape inDims_;
+  TensorShape targetDims_;
+  TensorShape outDims_;
+};
+}  // namespace paddle
diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp
index 4b92b5d163ad107c0783beae45f8c936112fcccf..d5621412caee843e24a0d0c9b7096402765738c7 100644
--- a/paddle/gserver/layers/Layer.cpp
+++ b/paddle/gserver/layers/Layer.cpp
@@ -359,12 +359,11 @@ void Layer::backwardActivation() {
   /* Do error clipping */
   if (config_.error_clipping_threshold() > 0.0f) {
     if (FLAGS_log_error_clipping) {
-      CpuVector outGradVec(0, nullptr);
-      outGradVec.subVecFrom(
-          output_.grad->getData(), 0, output_.grad->getElementCnt());
-      real maxAbsGrad = outGradVec.getAbsMax();
+      VectorPtr outGradVec = Vector::create(
+          output_.grad->getData(), output_.grad->getElementCnt(), useGpu_);
+      real maxAbsGrad = outGradVec->getAbsMax();
       if (maxAbsGrad > config_.error_clipping_threshold()) {
-        real avgAbsGrad = outGradVec.getAbsSum() / outGradVec.getSize();
+        real avgAbsGrad = outGradVec->getAbsSum() / outGradVec->getSize();
         LOG(INFO) << " layer=" << config_.name() << " need clipping,"
                   << " max error=" << maxAbsGrad << " avg error=" << avgAbsGrad;
       }
diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt
index 92f6cbcfe5a0e23c5939b1689a3e339367450387..a43adc7ce7db937bd62ea9bf1533b8a5899c259a 100644
--- a/paddle/gserver/tests/CMakeLists.txt
+++ b/paddle/gserver/tests/CMakeLists.txt
@@ -56,7 +56,7 @@ add_test(NAME test_DetectionOutput
 add_unittest_without_exec(test_ConvUnify
     test_ConvUnify.cpp
     LayerGradUtil.cpp)
-    
+
 add_test(NAME test_ConvUnify
     COMMAND test_ConvUnify)
 ################# test_BatchNorm #######################
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 67251f08e34faff57d9e6fd6a1163ba655619a8b..9af083468c0f01218117211f9e4931ca0669e96a 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1802,6 +1802,34 @@ TEST(Layer, RowConvLayer) {
   }
 }
 
+TEST(Layer, CropLayer) {
+  TestConfig config;
+  // config input_0
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  ImageConfig* img = input->mutable_image_conf();
+  img->set_channels(4);
+  img->set_img_size(16);
+  config.layerConfig.set_axis(2);
+  config.layerConfig.add_offset(0);
+  config.layerConfig.add_offset(0);
+
+  // config input_1
+  config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0});
+  input = config.layerConfig.add_inputs();
+  img = input->mutable_image_conf();
+  img->set_channels(2);
+  img->set_img_size(8);
+
+  // config crop layer
+  config.layerConfig.set_type("crop");
+  config.layerConfig.set_name("cropLayer");
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "crop", 100, false, useGpu, false);
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc
index 1579174b1a6ff08824629d833d01411cff651f48..f61e67a32906083881dd7f47433521876be9b355 100644
--- a/paddle/memory/detail/system_allocator.cc
+++ b/paddle/memory/detail/system_allocator.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/memory/detail/system_allocator.h"
 #include "paddle/platform/assert.h"
-#include "paddle/platform/error.h"
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/gpu_info.h"
 
 #include <stdlib.h>    // for malloc and free
@@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
   // process is terminating, in which case we don't care if
   // cudaFree succeeds.
   if (err != cudaErrorCudartUnloading) {
-    platform::throw_on_error(err,
-                             "cudaFree{Host} failed in GPUAllocator::Free.");
+    PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
   }
 }
 
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index f47c3a42083f289d6c99fe6df62e3478e0363e31..a37720e5093342f5e02bd9a15a3099de434d6396 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -27,7 +27,8 @@ function(op_library TARGET)
     endif()
 
     list(LENGTH cu_srcs cu_srcs_len)
-    if (${cu_srcs_len} EQUAL 0)
+    list(LENGTH op_library_DEPS dep_len)
+    if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0)
         message(WARNING "The op library ${TARGET} not support GPU!")
     endif()
 
@@ -47,3 +48,8 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu)
 op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
 op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
+
+op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
+        softmax_op net)
+
+op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 41d044cdb72b5fb2a7f8654e8ad103778e0857d1..f59a027407d18c428c60b9e79eda7aa759d957d6 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -49,10 +49,25 @@ The equation is: Out = X + Y
 )DOC");
   }
 };
+
+class AddOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "AddOpGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
+REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
+
 typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
     AddKernel_CPU_float;
 REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h
index e08b3fb18775e2536a13bc838f40472c5c3e7ff7..39d54a63bd16cdafeec1cfcd86ef5d142382e880 100644
--- a/paddle/operators/add_op.h
+++ b/paddle/operators/add_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include "glog/logging.h"
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/operator.h"
 
 namespace paddle {
@@ -29,8 +30,10 @@ public:
 
     output->mutable_data<T>(context.GetPlace());
 
-    output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
-        input0.flat<T>() + input1.flat<T>();
+    framework::EigenVector<T>::Flatten(*output).device(
+        *(context.GetEigenDevice<Place>())) =
+        framework::EigenVector<T>::Flatten(input0) +
+        framework::EigenVector<T>::Flatten(input1);
   }
 };
 
diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc
index 53b354fedcacf2176aed8b504daf2046bdf96bb6..7fc1049893e171a17af92da7e813b2463874c9de 100644
--- a/paddle/operators/add_op_test.cc
+++ b/paddle/operators/add_op_test.cc
@@ -16,8 +16,13 @@ limitations under the License. */
 #define private public
 #include <paddle/framework/op_registry.h>
 USE_OP(add_two);
+// USE_OP(add_two_grad);
+
 TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-}
\ No newline at end of file
+  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
+  auto it1 = grad_creators.find("add_two");
+  ASSERT_NE(it1, grad_creators.end());
+}
diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..01e96f4c4817466e3266ca57a0d0ae2368b3e097
--- /dev/null
+++ b/paddle/operators/fc_op.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/framework/net.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+
+class FullyConnectedOp : public framework::PlainNet {
+public:
+  void Init() override {
+    AddOp(framework::OpRegistry::CreateOp("mul",
+                                          {
+                                              Input("X"), Input("W"),
+                                          },
+                                          {Output("before_act")},
+                                          {}));
+    auto b = Input("b");
+    if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
+      AddOp(framework::OpRegistry::CreateOp("rowwise_add",
+                                            {Output("before_act"), Input("b")},
+                                            {Output("before_act")},
+                                            {}));
+    }
+
+    auto activation = GetAttr<std::string>("activation");
+    AddOp(framework::OpRegistry::CreateOp(
+        activation, {Output("before_act")}, {Output("Y")}, {}));
+    CompleteAddOp(false);
+  }
+};
+
+class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+  FullyConnectedOpMaker(framework::OpProto *proto,
+                        framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "the input of fc operator");
+    AddInput("W", "the weight of fc operator");
+    AddInput("b", "the bias of fc operator");
+
+    AddOutput("Y", "the output of fc operator");
+    AddOutput(
+        "before_act", "the before activation output of fc operator", true);
+    AddAttr<std::string>("activation", "The activation key for fc layer")
+        .SetDefault("sigmoid")
+        .InEnum({"sigmoid", "softmax"});
+
+    //! TODO(yuyang18): Complete comment;
+    AddComment("FullyConnected Operator");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+USE_OP(mul);
+USE_OP(rowwise_add);
+USE_OP(sigmoid);
+USE_OP(softmax);
+
+REGISTER_OP(fc,
+            paddle::operators::FullyConnectedOp,
+            paddle::operators::FullyConnectedOpMaker);
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 713b2a5dc83d8dd5a3d944101591d75cb19fe04f..ebf345194c6d15fca20b21e35c42ad6c775255d8 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
   }
 };
 
+class MulOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "MulGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel<paddle::platform::CPUPlace>);
diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..04df87a3add2af7daa127a072f7b690f6cf94327
--- /dev/null
+++ b/paddle/operators/sgd_op.cc
@@ -0,0 +1,61 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sgd_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+
+class SGDOp : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {
+    PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
+    PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
+    PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] must be set");
+    PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] must be set");
+    PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] must be set");
+    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
+                   "The two inputs of SGDOp must have the same dimension.");
+    outputs[0]->set_dims(inputs[0]->dims());
+  }
+};
+
+class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("param", "input parameter");
+    AddInput("grad", "input gradient");
+    AddOutput("param_out", "output parameter");
+    AddAttr<float>("learning_rate", "learning rate of sgd");
+    AddComment(R"DOC(
+
+The simplest SGD algorithm:
+
+param_out = param - learning_rate * grad;
+
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
+typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
+    SGDOpKernel_CPU_float;
+REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
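
As a sanity check, the update rule documented above is easy to mirror in
plain numpy (a minimal sketch, independent of the operator framework; the
shapes simply echo the unit test added later in this patch):

    import numpy as np

    # Reference semantics of the SGD kernel:
    # param_out = param - learning_rate * grad
    def sgd_update(param, grad, learning_rate):
        return param - learning_rate * grad

    param = np.random.random((342, 345)).astype("float32")
    grad = np.random.random((342, 345)).astype("float32")
    out = sgd_update(param, grad, 0.1)
    assert out.shape == param.shape
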
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..400425db10896e3970fc7468e34aba596a536184
--- /dev/null
+++ b/paddle/operators/sgd_op.cu
@@ -0,0 +1,5 @@
+#include "paddle/operators/sgd_op.h"
+#include "paddle/framework/op_registry.h"
+
+typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float;
+REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
\ No newline at end of file
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b2d214618e5c7c15695bd66604139d805255c47
--- /dev/null
+++ b/paddle/operators/sgd_op.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "glog/logging.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class SGDOpKernel : public framework::OpKernel {
+public:
+  void Compute(const framework::KernelContext& ctx) const override {
+    auto param = ctx.Input("param")->Get<framework::Tensor>();
+    auto grad = ctx.Input("grad")->Get<framework::Tensor>();
+    auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
+    float lr = ctx.op_.GetAttr<float>("learning_rate");
+
+    param_out->mutable_data<T>(ctx.GetPlace());
+
+    framework::EigenVector<T>::Flatten(*param_out)
+        .device(*(ctx.GetEigenDevice<Place>())) =
+        framework::EigenVector<T>::Flatten(param) -
+        lr * framework::EigenVector<T>::Flatten(grad);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/sgd_op_test.cc b/paddle/operators/sgd_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..75137259f5e608b259b073101353e5818bb17c92
--- /dev/null
+++ b/paddle/operators/sgd_op_test.cc
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <paddle/framework/op_registry.h>
+USE_OP(sgd);
+TEST(SGDOp, GetOpProto) {
+  auto& protos = paddle::framework::OpRegistry::protos();
+  auto it = protos.find("sgd");
+  ASSERT_NE(it, protos.end());
+}
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 45ae277c538ca90716febaf2f3d92b560149d147..16348db020064447b4445f0075c0591972ca3091 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -39,11 +39,24 @@ public:
   }
 };
 
+class SigmoidOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SigmoidGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(sigmoid,
             paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
+REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     sigmoid, paddle::operators::SigmoidKernel<paddle::platform::CPUPlace>);
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 4ca7be359e210d7a31aef94e498f37a1ad4879a2..146326d28330a275238317b6129b3c252797c029 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -40,10 +40,23 @@ public:
   }
 };
 
+class SoftmaxOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SoftmaxOpGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
+REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
+
 REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<paddle::platform::CPUPlace>);
diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt
index 6ac4035c0f863c5f63d17b523a7a8be668ff3da0..bd77bb7daa50e0b273f110624ddf6f4b79a3ceab 100644
--- a/paddle/platform/CMakeLists.txt
+++ b/paddle/platform/CMakeLists.txt
@@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
 
 add_subdirectory(dynload)
 
+cc_test(enforce_test SRCS enforce_test.cc)
+
 IF(WITH_GPU)
     set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
 ELSE()
diff --git a/paddle/platform/cpu_info.cc b/paddle/platform/cpu_info.cc
index dfab391cfbe1f04bc2a998233f7e7909579ca72b..78e1fa9df56b1623bfd9a53c6a37524d29648afc 100644
--- a/paddle/platform/cpu_info.cc
+++ b/paddle/platform/cpu_info.cc
@@ -22,7 +22,6 @@ limitations under the License. */
 #endif
 
 #include "gflags/gflags.h"
-#include "paddle/platform/error.h"
 
 DEFINE_double(fraction_of_cpu_memory_to_use, 1,
               "Default use 100% of CPU memory for PaddlePaddle,"
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index f226a75c20b7a75e5f884cd158d139ebb8b34e47..fe6f13e399a78f9e5230ae52b0f67ab465af373b 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -11,12 +11,13 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/place.h"
+
 #ifndef PADDLE_ONLY_CPU
 #include "paddle/platform/dynload/cublas.h"
 #include "paddle/platform/dynload/cudnn.h"
 #include "paddle/platform/dynload/curand.h"
-#include "paddle/platform/error.h"
 #include "paddle/platform/gpu_info.h"
 #define EIGEN_USE_GPU
 #endif
@@ -71,8 +72,7 @@ class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
     GPUPlaceGuard guard(gpu_place_);
-    paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
-                                     "cudaStreamCreate failed");
+    PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
     eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
     eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
   }
@@ -83,8 +83,8 @@ class CUDADeviceContext : public DeviceContext {
   }
 
   void Wait() {
-    paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
-                                     "cudaStreamSynchronize failed");
+    PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
+                   "cudaStreamSynchronize failed");
   }
 
   cudaStream_t stream() { return stream_; }
@@ -94,12 +94,11 @@ class CUDADeviceContext : public DeviceContext {
   cublasHandle_t cublas_handle() {
     if (!blas_handle_) {
       GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
-                         CUBLAS_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
                      "cublasCreate failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
-                         blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
-                     "cublasSetStream failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
+          "cublasSetStream failed");
     }
     return blas_handle_;
   }
@@ -107,12 +106,11 @@ class CUDADeviceContext : public DeviceContext {
   cudnnHandle_t cudnn_handle() {
     if (!dnn_handle_) {
       GPUPlaceGuard guard(gpu_place_);
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
-                         CUDNN_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
                      "cudnnCreate failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
-                         dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
-                     "cudnnSetStream failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
+          "cudnnSetStream failed");
     }
     return dnn_handle_;
   }
@@ -121,16 +119,15 @@ class CUDADeviceContext : public DeviceContext {
     if (!rand_generator_) {
       GPUPlaceGuard guard(gpu_place_);
       PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
-                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
-                         CURAND_STATUS_SUCCESS,
+                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
                      "curandCreateGenerator failed");
       PADDLE_ENFORCE(
           paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
-              rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
+              rand_generator_, random_seed_),
           "curandSetPseudoRandomGeneratorSeed failed");
-      PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
-                         rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
-                     "curandSetStream failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
+          "curandSetStream failed");
     }
     return rand_generator_;
   }
@@ -138,26 +135,23 @@ class CUDADeviceContext : public DeviceContext {
   ~CUDADeviceContext() {
     Wait();
     if (blas_handle_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
-                         CUBLAS_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
                      "cublasDestroy failed");
     }
 
     if (dnn_handle_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
-                         CUDNN_STATUS_SUCCESS,
+      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
                      "cudnnDestroy failed");
     }
 
     if (rand_generator_) {
-      PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
-                         rand_generator_) == CURAND_STATUS_SUCCESS,
-                     "curandDestroyGenerator failed");
+      PADDLE_ENFORCE(
+          paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
+          "curandDestroyGenerator failed");
     }
     eigen_stream_.reset();
     eigen_device_.reset();
-    paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
-                                     "cudaStreamDestroy failed");
+    PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
   }
 
  private:
diff --git a/paddle/platform/dynload/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc
index dd914e006d54c423ffea56ffaaafe7dcba416361..ae9a0a982c73de05821579d22b7f9ad99f24a92b 100644
--- a/paddle/platform/dynload/dynamic_loader.cc
+++ b/paddle/platform/dynload/dynamic_loader.cc
@@ -19,7 +19,7 @@ limitations under the License. */
 #include <string>
 #include "gflags/gflags.h"
 #include "glog/logging.h"
-#include "paddle/framework/enforce.h"
+#include "paddle/platform/enforce.h"
 
 DEFINE_string(cudnn_dir, "",
               "Specify path for loading libcudnn.so. For instance, "
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d440dec48e7a4cba404bc297eca5a451a144d93
--- /dev/null
+++ b/paddle/platform/enforce.h
@@ -0,0 +1,141 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <paddle/string/printf.h>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+
+#ifndef PADDLE_ONLY_CPU
+
+#include "paddle/platform/dynload/cublas.h"
+#include "paddle/platform/dynload/cudnn.h"
+#include "paddle/platform/dynload/curand.h"
+
+#include <cublas_v2.h>
+#include <cudnn.h>
+#include <curand.h>
+#include <thrust/system/cuda/error.h>
+#include <thrust/system_error.h>
+
+#endif  // PADDLE_ONLY_CPU
+
+namespace paddle {
+namespace platform {
+
+// Because most enforce conditions evaluate to true, we use
+// __builtin_expect to tell the compiler that the failure branch is
+// unlikely, so it can generate code with better branch prediction.
+// Note: __builtin_expect is a GCC/Clang builtin, not part of standard C++.
+// For more details, please check https://stackoverflow.com/a/43870188/724872.
+#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
+
+#ifndef PADDLE_ONLY_CPU
+
+template <typename... Args>
+inline void throw_on_error(cudaError_t e, const Args&... args) {
+  if (UNLIKELY(e)) {
+    // clang-format off
+    throw thrust::system_error(
+        e, thrust::cuda_category(),
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    // clang-format on
+  }
+}
+
+template <typename... Args>
+inline void throw_on_error(curandStatus_t stat, const Args&... args) {
+  if (stat != CURAND_STATUS_SUCCESS) {
+    // clang-format off
+    throw thrust::system_error(
+        cudaErrorLaunchFailure, thrust::cuda_category(),
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    // clang-format on
+  }
+}
+
+template <typename... Args>
+inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
+  if (stat == CUDNN_STATUS_SUCCESS) {
+    return;
+  } else {
+    // clang-format off
+    throw std::runtime_error(
+        platform::dynload::cudnnGetErrorString(stat) +
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    // clang-format on
+  }
+}
+
+template <typename... Args>
+inline void throw_on_error(cublasStatus_t stat, const Args&... args) {
+  std::string err;
+  if (stat == CUBLAS_STATUS_SUCCESS) {
+    return;
+  } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
+    err = "CUBLAS: not initialized, ";
+  } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
+    err = "CUBLAS: alloc failed, ";
+  } else if (stat == CUBLAS_STATUS_INVALID_VALUE) {
+    err = "CUBLAS: invalid value, ";
+  } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) {
+    err = "CUBLAS: arch mismatch, ";
+  } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) {
+    err = "CUBLAS: mapping error, ";
+  } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) {
+    err = "CUBLAS: execution failed, ";
+  } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) {
+    err = "CUBLAS: internal error, ";
+  } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) {
+    err = "CUBLAS: not supported, ";
+  } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
+    err = "CUBLAS: license error, ";
+  }
+  throw std::runtime_error(err + string::Sprintf(args...) +
+                           string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+}
+
+#endif  // PADDLE_ONLY_CPU
+
+template <typename... Args>
+inline void throw_on_error(int stat, const Args&... args) {
+  if (UNLIKELY(!(stat))) {
+    throw std::runtime_error(
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+  }
+}
+
+#define PADDLE_THROW(...)                                     \
+  do {                                                        \
+    throw std::runtime_error(                                 \
+        string::Sprintf(__VA_ARGS__) +                        \
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
+  } while (0)
+
+/**
+ * @brief Enforce a condition; otherwise, throw an exception that carries
+ *        the formatted message and the failing file/line.
+ */
+#define PADDLE_ENFORCE(condition, ...)                          \
+  do {                                                          \
+    ::paddle::platform::throw_on_error(condition, __VA_ARGS__); \
+  } while (0)
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/framework/enforce_test.cc b/paddle/platform/enforce_test.cc
similarity index 85%
rename from paddle/framework/enforce_test.cc
rename to paddle/platform/enforce_test.cc
index f8da1a192f63a54324d80725c9d2f156fb11a481..d7152f81509a35e4ce36d5649e7d209f51e34b86 100644
--- a/paddle/framework/enforce_test.cc
+++ b/paddle/platform/enforce_test.cc
@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <gtest/gtest.h>
-#include <paddle/framework/enforce.h>
+#include "paddle/platform/enforce.h"
+#include "gtest/gtest.h"
 
 TEST(ENFORCE, OK) {
   PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
@@ -23,13 +23,14 @@ TEST(ENFORCE, FAILED) {
   bool in_catch = false;
   try {
     PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
-  } catch (paddle::framework::EnforceNotMet err) {
+  } catch (const std::runtime_error& error) {
     in_catch = true;
     std::string msg = "Enforce is not ok 123 at all";
-    const char* what = err.what();
+    const char* what = error.what();
     for (size_t i = 0; i < msg.length(); ++i) {
       ASSERT_EQ(what[i], msg[i]);
     }
   }
   ASSERT_TRUE(in_catch);
-}
\ No newline at end of file
+}
diff --git a/paddle/platform/error.h b/paddle/platform/error.h
deleted file mode 100644
index 93424bb61096503a4843da7942853a113f612e3b..0000000000000000000000000000000000000000
--- a/paddle/platform/error.h
+++ /dev/null
@@ -1,87 +0,0 @@
-#pragma once
-
-#include <sstream>
-#include <stdexcept>
-#include <string>
-
-#ifndef PADDLE_ONLY_CPU
-
-#include <cublas_v2.h>
-#include <cudnn.h>
-#include <curand.h>
-#include <thrust/system/cuda/error.h>
-#include <thrust/system_error.h>
-
-#endif  // PADDLE_ONLY_CPU
-
-namespace paddle {
-namespace platform {
-
-#ifndef PADDLE_ONLY_CPU
-
-inline void throw_on_error(cudaError_t e, const char* message) {
-  if (e) {
-    throw thrust::system_error(e, thrust::cuda_category(), message);
-  }
-}
-
-inline void throw_on_error(curandStatus_t stat, const char* message) {
-  if (stat != CURAND_STATUS_SUCCESS) {
-    throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
-                               message);
-  }
-}
-
-inline void throw_on_error(cudnnStatus_t stat, const char* message) {
-  std::stringstream ss;
-  if (stat == CUDNN_STATUS_SUCCESS) {
-    return;
-  } else {
-    ss << cudnnGetErrorString(stat);
-    ss << ", " << message;
-    throw std::runtime_error(ss.str());
-  }
-}
-
-inline void throw_on_error(cublasStatus_t stat, const char* message) {
-  std::stringstream ss;
-  if (stat == CUBLAS_STATUS_SUCCESS) {
-    return;
-  } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
-    ss << "CUBLAS: not initialized";
-  } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
-    ss << "CUBLAS: alloc failed";
-  } else if (stat == CUBLAS_STATUS_INVALID_VALUE) {
-    ss << "CUBLAS: invalid value";
-  } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) {
-    ss << "CUBLAS: arch mismatch";
-  } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) {
-    ss << "CUBLAS: mapping error";
-  } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) {
-    ss << "CUBLAS: execution failed";
-  } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) {
-    ss << "CUBLAS: internal error";
-  } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) {
-    ss << "CUBLAS: not supported";
-  } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
-    ss << "CUBLAS: license error";
-  }
-  ss << ", " << message;
-  throw std::runtime_error(ss.str());
-}
-
-inline void throw_on_error(cublasStatus_t stat) {
-  const char* message = "";
-  throw_on_error(stat, message);
-}
-
-#endif  // PADDLE_ONLY_CPU
-
-inline void throw_on_error(int stat, const char* message) {
-  if (stat) {
-    throw std::runtime_error(message + (", stat = " + std::to_string(stat)));
-  }
-}
-
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc
index a1383d3524aedf834c329425419b989d47668bea..cf9921e870d47fe77c0cca80828dbf2bb36ccda8 100644
--- a/paddle/platform/gpu_info.cc
+++ b/paddle/platform/gpu_info.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/platform/gpu_info.h"
 #include "gflags/gflags.h"
-#include "paddle/platform/error.h"
+#include "paddle/platform/enforce.h"
 
 DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
               "Default use 95% of GPU memory for PaddlePaddle,"
@@ -25,7 +25,7 @@ namespace platform {
 
 int GetDeviceCount() {
   int count;
-  throw_on_error(
+  PADDLE_ENFORCE(
       cudaGetDeviceCount(&count),
       "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount");
   return count;
@@ -33,19 +33,19 @@ int GetDeviceCount() {
 
 int GetCurrentDeviceId() {
   int device_id;
-  throw_on_error(
+  PADDLE_ENFORCE(
       cudaGetDevice(&device_id),
       "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId");
   return device_id;
 }
 
 void SetDeviceId(int id) {
-  throw_on_error(cudaSetDevice(id),
+  PADDLE_ENFORCE(cudaSetDevice(id),
                  "cudaSetDevice failed in paddle::platform::SetDeviceId");
 }
 
 void GpuMemoryUsage(size_t& available, size_t& total) {
-  throw_on_error(cudaMemGetInfo(&available, &total),
+  PADDLE_ENFORCE(cudaMemGetInfo(&available, &total),
                  "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage");
 }
 
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 00b14a94321990baef6de35df547eed04b3da04f..6354dd211d5d036e1b5971babaf624e8f847a92b 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -1,2 +1,2 @@
 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-        add_op mul_op rowwise_add_op sigmoid_op softmax_op)
+        add_op fc_op sgd_op)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index fc9c6544c3cbf5a804b2d052f738bd483d6bf41b..54707a2859693af4a80692bf5cebab59c43ffbc3 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <Python.h>
 #include <paddle/framework/op_registry.h>
+#include <paddle/framework/operator.h>
 #include <paddle/framework/scope.h>
 #include <paddle/pybind/tensor_bind.h>
 #include <pybind11/numpy.h>
@@ -26,10 +27,8 @@ namespace py = pybind11;
 namespace pd = paddle::framework;
 
 USE_OP(add_two);
-USE_OP(softmax);
-USE_OP(mul);
-USE_OP(rowwise_add);
-USE_OP(sigmoid);
+USE_OP_WITHOUT_KERNEL(fc);
+USE_OP(sgd);
 
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of Paddle Paddle");
@@ -53,7 +52,9 @@ PYBIND11_PLUGIN(core) {
              self.mutable_data<int>(paddle::platform::CPUPlace());
            })
       .def("set", paddle::pybind::PyTensorSetFromArray<float>)
-      .def("set", paddle::pybind::PyTensorSetFromArray<int>);
+      .def("set", paddle::pybind::PyTensorSetFromArray<int>)
+      .def("shape",
+           [](pd::Tensor& self) { return pd::vectorize(self.dims()); });
 
   py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class.
 
@@ -83,15 +84,16 @@ All parameter, weight, gradient are variables in Paddle.
 
   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
-  m.def("get_all_op_protos", []() -> std::vector<std::string> {
+  m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
     auto& protos = pd::OpRegistry::protos();
-    std::vector<std::string> ret_values;
+    std::vector<py::bytes> ret_values;
     for (auto it = protos.begin(); it != protos.end(); ++it) {
       PADDLE_ENFORCE(it->second.IsInitialized(),
                      "OpProto must all be initialized");
-      ret_values.emplace_back();
-      PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()),
+      std::string str;
+      PADDLE_ENFORCE(it->second.SerializeToString(&str),
                      "Serialize OpProto Error. This could be a bug of Paddle.");
+      ret_values.push_back(py::bytes(str));
     }
     return ret_values;
   });
@@ -101,17 +103,26 @@ All parameter, weight, gradient are variables in Paddle.
       .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
       .def("temp", pd::OperatorBase::TMP_VAR_NAME);
 
+  py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
+      .def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
+        return new paddle::platform::CPUDeviceContext();
+      });
+
   py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
       .def("__str__", &pd::OperatorBase::DebugString)
-      .def_static("create", [](const std::string& protobin) {
-        pd::OpDesc desc;
-        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                       "Cannot parse user input to OpDesc");
-        PADDLE_ENFORCE(desc.IsInitialized(),
-                       "User OpDesc is not initialized, reason %s",
-                       desc.InitializationErrorString());
-        return pd::OpRegistry::CreateOp(desc);
-      });
+      .def_static("create",
+                  [](py::bytes protobin) {
+                    pd::OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    return pd::OpRegistry::CreateOp(desc);
+                  })
+      .def("infer_shape", &pd::OperatorBase::InferShape)
+      .def("run", &pd::OperatorBase::Run)
+      .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
 
   return m.ptr();
 }
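
A minimal sketch of exercising the new bindings from Python (names follow
the unit tests later in this patch; `shape()` is the binding added above):

    import paddle.v2.framework.core as core

    scope = core.Scope(None)
    t = scope.create_var("X").get_tensor()
    t.set_dims([2, 3])
    t.alloc_float()
    # The new shape() binding returns the tensor dims as a Python list.
    assert t.shape() == [2, 3]
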
diff --git a/paddle/scripts/travis/check_style.sh b/paddle/scripts/travis/check_style.sh
index 8049aeb7b00870220e59c981addf6d70a66877c7..ec499a839ac6593bac788f4cca5e33afbed73010 100755
--- a/paddle/scripts/travis/check_style.sh
+++ b/paddle/scripts/travis/check_style.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 function abort(){
     echo "Your change doesn't follow PaddlePaddle's code style." 1>&2
-    echo "Please use pre-commit to reformat your code and git push again." 1>&2
+    echo "Please use pre-commit to check what is wrong." 1>&2
     exit 1
 }
 
@@ -19,7 +19,8 @@ ln -sf $TRAVIS_BUILD_DIR $GOPATH/src/github.com/PaddlePaddle/Paddle
 cd  $GOPATH/src/github.com/PaddlePaddle/Paddle/go; glide install; cd -
 
 if ! pre-commit run -a ; then
-  git diff  --exit-code
+    git diff
+    exit 1
 fi
 
 trap : 0
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 37cd16c79890738f6d8966579e15686c653d4df3..83f72c137bdf5e55f28be908321bd2ccd6c906fe 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -472,10 +472,16 @@ message LayerConfig {
   // blank label used in ctc loss
   optional uint32 blank = 52 [default = 0];
 
-  // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which 
+  // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which
   // controls the scope of pooling operation. can be set > 0.
   // leave empty or set to -1 to disable this stride pooling.
   optional int32 seq_pool_stride = 53 [default = -1];
+
+  // for crop layer
+  optional int32 axis = 54 [default = 2];
+  repeated uint32 offset = 55;
+  repeated uint32 shape = 56;
+
 }
 
 message EvaluatorConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 826ba2834a820d11e69feec5569ef3537194e3c3..ab81e67579e39a34e3ace18d14434eb86b66fa5b 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1575,7 +1575,13 @@ class MultiClassCrossEntropySelfNormCostLayer(LayerBase):
 
 @config_layer('fc')
 class FCLayer(LayerBase):
-    def __init__(self, name, size, inputs, bias=True, **xargs):
+    def __init__(self,
+                 name,
+                 size,
+                 inputs,
+                 bias=True,
+                 error_clipping_threshold=None,
+                 **xargs):
         super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs)
         for input_index in xrange(len(self.inputs)):
             input_layer = self.get_input_layer(input_index)
@@ -1592,6 +1598,8 @@ class FCLayer(LayerBase):
             self.create_input_parameter(input_index, psize, dims, sparse,
                                         format)
         self.create_bias_parameter(bias, self.config.size)
+        if error_clipping_threshold is not None:
+            self.config.error_clipping_threshold = error_clipping_threshold
 
 
 @config_layer('selective_fc')
@@ -1990,6 +1998,23 @@ class PadLayer(LayerBase):
         self.config.size = out_ch * out_h * out_w
 
 
+@config_layer('crop')
+class CropLayer(LayerBase):
+    def __init__(self, name, inputs, axis, offset, shape, **xargs):
+        super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs)
+        self.config.axis = axis
+        self.config.offset.extend(offset)
+        self.config.shape.extend(shape)
+
+        # get channel, width and height from input_0 layer
+        input_layer = self.get_input_layer(0)
+        image_conf = self.config.inputs[0].image_conf
+        image_conf.img_size = input_layer.width
+        image_conf.img_size_y = input_layer.height
+        image_conf.channels = input_layer.size / (input_layer.width *
+                                                  input_layer.height)
+
+
 @config_layer('batch_norm')
 class BatchNormLayer(LayerBase):
     layer_type = 'batch_norm'
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 78aa0778f8d1dca9fae82f0411be5a00e636cbc9..fdb6f83f2ba510232714fb8a9c7c1af837a753ff 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -127,6 +127,7 @@ __all__ = [
     'dropout_layer',
     'prelu_layer',
     'gated_unit_layer',
+    'crop_layer',
 ]
 
 
@@ -218,6 +219,7 @@ class LayerType(object):
     SMOOTH_L1 = 'smooth_l1'
 
     PRELU = 'prelu'
+    CROP_LAYER = 'crop'
 
     @staticmethod
     def is_layer_type(type_name):
@@ -5970,3 +5972,52 @@ def gated_unit_layer(input,
         name="%s_gated_act" % name,
         input=dotmul_operator(input_proj, gate),
         layer_attr=layer_attr)
+
+
+@wrap_name_default()
+@layer_support()
+def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
+    """
+    The crop layer crops images according to the offset and shape. The crop
+    shape can be set explicitly via the 'shape' argument, or taken from the
+    reference input layer.
+
+    The example usage is:
+
+    .. code-block:: python
+
+        crop = crop_layer(input=[image_input, reference_input],
+                          axis=2, offset=[2, 3])
+
+    :param input: The input layer. If two inputs are given, the second one
+                    will be regarded as the reference input.
+    :type input: LayerOutput or Sequence
+    :param offset: The crop offset in each dimension.
+    :type offset: Sequence
+    :param axis: The start axis to be cropped. For an image input layer:
+        - 0: batch size
+        - 1: channels
+        - 2: height
+        - 3: width
+    :type axis: int
+    :param shape: The shape to be cropped. Default is None.
+    :type shape: Sequence | None
+    :param name: Name of this layer.
+    :type name: basestring
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    if isinstance(input, LayerOutput):
+        input = [input]
+    else:
+        assert isinstance(input, collections.Sequence)
+    l = Layer(
+        inputs=[x.name for x in input],
+        axis=axis,
+        offset=offset,
+        shape=shape,
+        name=name,
+        type=LayerType.CROP_LAYER,
+        **ExtraLayerAttribute.to_kwargs(layer_attr))
+    return LayerOutput(
+        name=name,
+        layer_type=LayerType.CROP_LAYER,
+        parents=input,
+        size=l.config.size)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_crop.py b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..8314a7e9a5586647c70ff010156817110919c72b
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py
@@ -0,0 +1,21 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=1000, learning_rate=1e-5)
+
+data = data_layer(name='data', size=2016, height=48, width=42)
+reference_data = data_layer(name='reference', size=768, height=16, width=16)
+
+conv = img_conv_layer(
+    input=data,
+    filter_size=3,
+    num_channels=1,
+    num_filters=16,
+    padding=1,
+    act=LinearActivation(),
+    bias_attr=True)
+
+pool = img_pool_layer(input=conv, pool_size=2, stride=2, pool_type=MaxPooling())
+
+crop = crop_layer(input=[pool, reference_data], axis=2)
+
+outputs(crop)
diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py
index 2e4beb6882789249db09705f3f4d6c5c19e492cd..90830515c1e8e6f5260cfca631e02a3a52cedbe5 100644
--- a/python/paddle/v2/dataset/__init__.py
+++ b/python/paddle/v2/dataset/__init__.py
@@ -26,8 +26,9 @@ import sentiment
 import wmt14
 import mq2007
 import flowers
+import voc2012
 
 __all__ = [
     'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment',
-    'uci_housing', 'wmt14', 'mq2007', 'flowers'
+    'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
 ]
diff --git a/python/paddle/v2/dataset/tests/voc2012_test.py b/python/paddle/v2/dataset/tests/voc2012_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e72ebf5eac0508d12783f9ceaa6eef0fa6d353
--- /dev/null
+++ b/python/paddle/v2/dataset/tests/voc2012_test.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.v2.dataset.voc2012
+import unittest
+
+
+class TestVOC(unittest.TestCase):
+    def check_reader(self, reader):
+        count = 0
+        for l in reader():
+            self.assertEqual(l[0].size, 3 * l[1].size)
+            count += 1
+        return count
+
+    def test_train(self):
+        count = self.check_reader(paddle.v2.dataset.voc2012.train())
+        self.assertEqual(count, 2913)
+
+    def test_test(self):
+        count = self.check_reader(paddle.v2.dataset.voc2012.test())
+        self.assertEqual(count, 1464)
+
+    def test_val(self):
+        count = self.check_reader(paddle.v2.dataset.voc2012.val())
+        self.assertEqual(count, 1449)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/dataset/voc2012.py b/python/paddle/v2/dataset/voc2012.py
new file mode 100644
index 0000000000000000000000000000000000000000..617e212d67fbe37f9d9663e9c83c62045411fa77
--- /dev/null
+++ b/python/paddle/v2/dataset/voc2012.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image dataset for segmentation.
+The 2012 dataset contains images from 2008-2011 for which additional
+segmentations have been prepared. As in previous years the assignment
+to training/test sets has been maintained. The total number of images
+with segmentation has been increased from 7,062 to 9,993.
+"""
+
+import tarfile
+import io
+import numpy as np
+from paddle.v2.dataset.common import download
+from paddle.v2.image import *
+from PIL import Image
+
+__all__ = ['train', 'test', 'val']
+
+VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\
+VOCtrainval_11-May-2012.tar'
+
+VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
+SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
+DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
+LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'
+
+CACHE_DIR = 'voc2012'
+
+
+def reader_creator(filename, sub_name):
+
+    tarobject = tarfile.open(filename)
+    name2mem = {}
+    for ele in tarobject.getmembers():
+        name2mem[ele.name] = ele
+
+    def reader():
+        set_file = SET_FILE.format(sub_name)
+        sets = tarobject.extractfile(name2mem[set_file])
+        for line in sets:
+            line = line.strip()
+            data_file = DATA_FILE.format(line)
+            label_file = LABEL_FILE.format(line)
+            data = tarobject.extractfile(name2mem[data_file]).read()
+            label = tarobject.extractfile(name2mem[label_file]).read()
+            data = Image.open(io.BytesIO(data))
+            label = Image.open(io.BytesIO(label))
+            data = np.array(data)
+            label = np.array(label)
+            yield data, label
+
+    return reader
+
+
+def train():
+    """
+    Create a train dataset reader containing 2913 images in HWC order.
+    """
+    return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'trainval')
+
+
+def test():
+    """
+    Create a test dataset reader containing 1464 images in HWC order.
+    """
+    return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'train')
+
+
+def val():
+    """
+    Create a val dataset reader containing 1449 images in HWC order.
+    """
+    return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'val')
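
A short usage sketch of the readers above (the first call downloads and
caches the VOC2012 tarball; the size relation follows the unit test
earlier in this patch):

    import paddle.v2.dataset.voc2012 as voc2012

    # Each reader yields (image, label) numpy arrays: images in HWC order,
    # labels as per-pixel class maps, so image.size == 3 * label.size.
    for image, label in voc2012.val()():
        assert image.size == 3 * label.size
        break
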
diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py
index c2a7ae7692b08762ffbc91726be7bfa90e8ddedb..7248c3f52a9902e8c08ac2f1405801a5710459e5 100644
--- a/python/paddle/v2/framework/create_op_creation_methods.py
+++ b/python/paddle/v2/framework/create_op_creation_methods.py
@@ -217,6 +217,10 @@ def create_op_creation_method(op_proto):
         return core.Operator.create(opdesc.SerializeToString())
 
     __impl__.__doc__ = get_docstring_from_op_proto(op_proto)
+    __impl__.all_input_args = [var.name for var in op_proto.inputs]
+    __impl__.all_output_args = [var.name for var in op_proto.outputs]
+    __impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
+
     return __impl__
 
 
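
The three attributes attached above make each generated creation method
introspectable; `op_test_util.py` below relies on them to fill in inputs,
outputs, and attrs generically. A minimal sketch (names taken from the
add_two tests in this patch):

    import paddle.v2.framework.create_op_creation_methods as creation

    add_two = creation.op_creations.add_two
    # Declared argument names, as exposed by the new attributes.
    assert "X" in add_two.all_input_args and "Y" in add_two.all_input_args
    assert "Out" in add_two.all_output_args
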
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 4ce2bef6fcc4b8ddf5a6de3809a1891bce590aab..ec076e40c9312fee7f3ba030dc69208069fd45a8 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -1,3 +1,3 @@
 add_python_test(test_framework test_protobuf.py test_scope.py
     test_default_scope_funcs.py test_op_creation_methods.py
-    test_tensor.py)
+    test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py)
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1fa12cc89fa724994ea482ab0a3d78c03a9cdf0
--- /dev/null
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -0,0 +1,62 @@
+import paddle.v2.framework.core as core
+import unittest
+import numpy
+import paddle.v2.framework.create_op_creation_methods as creation
+
+
+class OpTestMeta(type):
+    """
+    Operator Test ClassMeta.
+
+    It injects a `test_all` method into the user's OperatorTest class, so
+    that the Python unittest module will run it.
+
+    `test_all` reads the values stored on `self`, uses them to create and
+    run an operator, and checks whether the op's outputs are correct.
+
+    See `test_add_two_op` for example usage.
+    """
+
+    def __new__(cls, name, bases, attrs):
+        obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs)
+
+        def test_all(self):
+            func = getattr(creation.op_creations, self.type, None)
+            self.assertIsNotNone(func)
+
+            scope = core.Scope(None)
+            kwargs = dict()
+
+            for in_name in func.all_input_args:
+                if hasattr(self, in_name):
+                    kwargs[in_name] = in_name
+                    var = scope.create_var(in_name).get_tensor()
+                    arr = getattr(self, in_name)
+                    var.set_dims(arr.shape)
+                    var.set(arr)
+                else:
+                    kwargs[in_name] = "@EMPTY@"
+
+            for out_name in func.all_output_args:
+                if hasattr(self, out_name):
+                    kwargs[out_name] = out_name
+                    scope.create_var(out_name).get_tensor()
+
+            for attr_name in func.all_attr_args:
+                if hasattr(self, attr_name):
+                    kwargs[attr_name] = getattr(self, attr_name)
+
+            op = func(**kwargs)
+
+            op.infer_shape(scope)
+
+            ctx = core.DeviceContext.cpu_context()
+            op.run(scope, ctx)
+
+            for out_name in func.all_output_args:
+                actual = numpy.array(scope.get_var(out_name).get_tensor())
+                expect = getattr(self, out_name)
+                numpy.testing.assert_almost_equal(actual, expect)
+
+        obj.test_all = test_all
+        return obj
diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..a06d7a78ecf838a49e5f2808d3686c6b92faa8ce
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_add_two_op.py
@@ -0,0 +1,17 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy
+
+
+class TestAddOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "add_two"
+        self.X = numpy.random.random((342, 345)).astype("float32")
+        self.Y = numpy.random.random((342, 345)).astype("float32")
+        self.Out = self.X + self.Y
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..59e7e61249e2a7d49a17e5d87209f03b8f35f730
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_fc_op.py
@@ -0,0 +1,43 @@
+import paddle.v2.framework.core as core
+import unittest
+import numpy
+import paddle.v2.framework.create_op_creation_methods as creation
+
+
+class TestFc(unittest.TestCase):
+    def test_fc(self):
+        scope = core.Scope(None)
+        x = scope.create_var("X")
+        x_tensor = x.get_tensor()
+        x_tensor.set_dims([1000, 784])
+        x_tensor.alloc_float()
+
+        w = scope.create_var("W")
+        w_tensor = w.get_tensor()
+        w_tensor.set_dims([784, 100])
+        w_tensor.alloc_float()
+
+        w_tensor.set(numpy.random.random((784, 100)).astype("float32"))
+
+        # Set a real numpy array here.
+        # x_tensor.set(numpy.array([]))
+
+        op = creation.op_creations.fc(X="X", Y="Y", W="W")
+
+        for out in op.outputs():
+            if scope.get_var(out) is None:
+                scope.create_var(out).get_tensor()
+
+        tensor = scope.get_var("Y").get_tensor()
+        op.infer_shape(scope)
+        self.assertEqual([1000, 100], tensor.shape())
+
+        ctx = core.DeviceContext.cpu_context()
+
+        op.run(scope, ctx)
+
+        # TODO: after the op runs, check whether Y matches the expected value.
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..405d73b224fa153e50b4ec408a921f2bdaab46aa
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -0,0 +1,18 @@
+import unittest
+import numpy
+from op_test_util import OpTestMeta
+
+
+class TestSGD(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "sgd"
+        self.param = numpy.random.random((342, 345)).astype("float32")
+        self.grad = numpy.random.random((342, 345)).astype("float32")
+        self.learning_rate = 0.1
+        self.param_out = self.param - self.learning_rate * self.grad
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/setup.py.in b/python/setup.py.in
index b1041f6102a56f5a200aa909e77729095c052f31..65a26940d4d703ea4fbb5022523a90716982ec10 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -20,6 +20,7 @@ setup_requires=["requests",
                 "matplotlib",
                 "rarfile",
                 "scipy>=0.19.0",
+                "Pillow",
                 "nltk"]
 
 if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: