diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 2af10a996c1ff2829d7b262809a0a78095ffeef9..6788cb34fbaf5941cbb1537c7a83577c623bf76a 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -6,13 +6,17 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
 
 if (WITH_GPU)
-  nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context)
+  nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto)
 else()
-  cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context)
+  cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto)
 endif ()
 
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
-cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
+if (WITH_GPU)
+  nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
+else()
+  cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
+endif()
 
 cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
 
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index f8a3be9a82bdbaf82550634d36122eb7bbe85e54..7b6dc09bdb5535488c8c4dbc71c9cd6a7998bd0b 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -189,62 +189,16 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
 
 void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
                        const platform::DeviceContext &dev_ctx) {
-  // TODO(typhoonzero): serialize to ostream
-  {  // the 1st field, uint32_t version
+  {  // the 1st field, uint32_t version for LoDTensor
     constexpr uint32_t version = 0;
     os.write(reinterpret_cast<const char *>(&version), sizeof(version));
   }
-  {  // the 2nd field, tensor description
-     // int32_t  size
-     // void*    protobuf message
-    proto::TensorDesc desc;
-    desc.set_data_type(framework::ToDataType(tensor.type()));
-    auto dims = framework::vectorize(tensor.dims());
-    auto *pb_dims = desc.mutable_dims();
-    pb_dims->Resize(static_cast<int>(dims.size()), 0);
-    std::copy(dims.begin(), dims.end(), pb_dims->begin());
-    int32_t size = desc.ByteSize();
-    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
-    auto out = desc.SerializeAsString();
-    os.write(out.data(), size);
-  }
-  {  // the 3rd field, tensor data
-    uint64_t size = tensor.memory_size();
-    auto *data_ptr = tensor.data<void>();
-    PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
-                   "Index overflow when writing tensor");
-    if (platform::is_gpu_place(tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
-      constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB
-      std::unique_ptr<char[]> buf(new char[kBufSize]);
-      auto &gpu_dev_ctx =
-          static_cast<const platform::CUDADeviceContext &>(dev_ctx);
-      platform::CPUPlace cpu;
-      uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
-      while (size != 0) {
-        size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
-        memory::Copy(cpu, buf.get(),
-                     boost::get<platform::CUDAPlace>(tensor.place()),
-                     reinterpret_cast<const void *>(data), size_to_write,
-                     gpu_dev_ctx.stream());
-        gpu_dev_ctx.Wait();
-        os.write(buf.get(), size_to_write);
-        data += size_to_write;
-        size -= size_to_write;
-      }
-#else
-      PADDLE_THROW("Unexpected branch");
-#endif
-    } else {
-      os.write(static_cast<const char *>(data_ptr),
-               static_cast<std::streamsize>(size));
-    }
-  }
-  {  // the 4th field, lod information
-     // uint64_t lod_level
-     // uint64_t lod_level_1 size in byte.
-     // int*     lod_level_1 data
-     // ...
+  {
+    // the 2st field, LoD information
+    // uint64_t lod_level
+    // uint64_t lod_level_1 size in byte.
+    // int*     lod_level_1 data
+    // ...
     auto lod = tensor.lod();
     uint64_t size = lod.size();
     os.write(reinterpret_cast<const char *>(&size), sizeof(size));
@@ -256,49 +210,19 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
                static_cast<std::streamsize>(size));
     }
   }
+  // the 3st field, Tensor
+  SerializeToStream(os, static_cast<Tensor>(tensor), dev_ctx);
 }
 
 void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
-  uint32_t version;
-  is.read(reinterpret_cast<char *>(&version), sizeof(version));
-  PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
-  proto::TensorDesc desc;
-  {  // int32_t size
-     // proto buffer
-    int32_t size;
-    is.read(reinterpret_cast<char *>(&size), sizeof(size));
-    std::unique_ptr<char[]> buf(new char[size]);
-    is.read(reinterpret_cast<char *>(buf.get()), size);
-    PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
-                   "Cannot parse tensor desc");
-  }
-  {  // read tensor
-    std::vector<int64_t> dims;
-    dims.reserve(static_cast<size_t>(desc.dims().size()));
-    std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
-    tensor->Resize(framework::make_ddim(dims));
-
-    void *buf;
-    platform::Place cpu = platform::CPUPlace();
-    switch (desc.data_type()) {
-      case proto::FP32:
-        buf = tensor->mutable_data<float>(cpu);
-        break;
-      case proto::FP64:
-        buf = tensor->mutable_data<double>(cpu);
-        break;
-      case proto::INT32:
-        buf = tensor->mutable_data<int>(cpu);
-        break;
-      case proto::INT64:
-        buf = tensor->mutable_data<int64_t>(cpu);
-        break;
-      default:
-        PADDLE_THROW("DataType %d not supported", desc.data_type());
-    }
-    is.read(static_cast<char *>(buf), tensor->memory_size());
-  }
-  {  // read lod
+  {
+    // the 1st field, unit32_t version for SelectedRows
+    uint32_t version;
+    is.read(reinterpret_cast<char *>(&version), sizeof(version));
+    PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+  }
+  {
+    // the 2st field, LoD information
     uint64_t lod_level;
     is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
     auto &lod = *tensor->mutable_lod();
@@ -312,6 +236,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
       lod[i] = tmp;
     }
   }
+  // the 3st filed, Tensor
+  DeserializeFromStream(is, static_cast<Tensor *>(tensor));
 }
 
 }  // namespace framework
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index 02d84b68233f2fdfc66e1df2fc7ce20307cadd94..0747c8db531d6ae443d76591b945cce0c9bbea2b 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -126,6 +126,20 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
   EXPECT_NE(t1.data<float>(), lod_tensor_.data<float>());
 }
 
+TEST_F(LoDTensorTester, SerializeAndDeserialize) {
+  LoDTensor dst_tensor;
+  platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
+  std::ostringstream oss;
+  SerializeToStream(oss, lod_tensor_, cpu_ctx);
+  std::istringstream iss(oss.str());
+  DeserializeFromStream(iss, &dst_tensor);
+  float* dst_ptr = dst_tensor.mutable_data<float>(platform::CPUPlace());
+  for (int i = 0; i < kLodTensorSize; ++i) {
+    EXPECT_EQ(dst_ptr[i], i);
+  }
+  EXPECT_EQ(dst_tensor.lod(), lod_tensor_.lod());
+}
+
 TEST(LodExpand, test) {
   LoD lod{{0, 2}};
   LoDTensor tensor;
diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc
index c74459c9dd7006a24615b1d6df041583088fb25c..82adfa7123a3cf40d929021602c45fe7d2e34ffa 100644
--- a/paddle/framework/selected_rows.cc
+++ b/paddle/framework/selected_rows.cc
@@ -12,5 +12,58 @@ limitations under the License. */
 #include "paddle/framework/selected_rows.h"
 
 namespace paddle {
-namespace framework {}  // namespace framework
+namespace framework {
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+                       const platform::DeviceContext& dev_ctx) {
+  {  // the 1st field, uint32_t version
+    constexpr uint32_t version = 0;
+    os.write(reinterpret_cast<const char*>(&version), sizeof(version));
+  }
+  {
+    // the 2st field, rows information
+    auto& rows = selected_rows.rows();
+    uint64_t size = rows.size();
+    os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+    for (uint64_t i = 0; i < size; ++i) {
+      os.write(reinterpret_cast<const char*>(&rows[i]), sizeof(rows[i]));
+    }
+  }
+  {
+    // the 3st field, the height of SelectedRows
+    int64_t height = selected_rows.height();
+    os.write(reinterpret_cast<const char*>(&height), sizeof(height));
+  }
+  // the 4st field, Tensor data
+  SerializeToStream(os, selected_rows.value(), dev_ctx);
+}
+
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
+  auto tensor = *selected_rows->mutable_value();
+  {
+    // the 1st field, unit32_t version for SelectedRows
+    uint32_t version;
+    is.read(reinterpret_cast<char*>(&version), sizeof(version));
+    PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+  }
+  {
+    // the 2st field, rows information
+    uint64_t size;
+    is.read(reinterpret_cast<char*>(&size), sizeof(size));
+    auto& rows = *selected_rows->mutable_rows();
+    rows.resize(size);
+    for (uint64_t i = 0; i < size; ++i) {
+      is.read(reinterpret_cast<char*>(&rows[i]), sizeof(int64_t));
+    }
+  }
+  {
+    // the 3st field, the height of the SelectedRows
+    int64_t height;
+    is.read(reinterpret_cast<char*>(&height), sizeof(int64_t));
+    selected_rows->set_height(height);
+  }
+  // the 4st field, tensor which contains the data
+  DeserializeFromStream(is, &tensor);
+}
+
+}  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index 0332b91323e3a4b4b80e02302ad3dcafe0986cde..699e392688e9889f050592172f8bfc45f855d0b1 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -59,5 +59,14 @@ class SelectedRows {
   int64_t height_;
 };
 
+/*
+ * Serialize/Desiralize SelectedRows to std::ostream
+ * You can pass ofstream or ostringstream to serilize to file
+ * or to a in memory string. GPU tensor will be copied to CPU.
+ */
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+                       const platform::DeviceContext& dev_ctx);
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc
index 4ee13a65d72e44693573397bb686b355effb2227..75487c4010391aa9e519d73058184fa936dabb84 100644
--- a/paddle/framework/selected_rows_test.cc
+++ b/paddle/framework/selected_rows_test.cc
@@ -43,5 +43,19 @@ TEST_F(SelectedRowsTester, complete_dims) {
   ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100}));
 }
 
+TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
+  SelectedRows dst_tensor;
+  platform::CPUDeviceContext cpu_ctx(place_);
+  std::ostringstream oss;
+
+  SerializeToStream(oss, *selected_rows_, cpu_ctx);
+
+  std::istringstream iss(oss.str());
+  DeserializeFromStream(iss, &dst_tensor);
+
+  ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows());
+  ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index ca76a9fcb9079bab22f7b192c45903852c91797f..a1b4a03289eca4c8b9d8c23ede4221853cb31f79 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -15,12 +15,13 @@
 #include <gtest/gtest.h>
 #include <string>
 
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+
 TEST(Tensor, Dims) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
-  Tensor tt;
+  framework::Tensor tt;
   tt.Resize({2, 3, 4});
-  DDim dims = tt.dims();
+  framework::DDim dims = tt.dims();
   ASSERT_EQ(arity(dims), 3);
   for (int i = 0; i < 3; ++i) {
     EXPECT_EQ(i + 2, dims[i]);
@@ -28,12 +29,12 @@ TEST(Tensor, Dims) {
 }
 
 TEST(Tensor, DataAssert) {
-  paddle::framework::Tensor src_tensor;
+  framework::Tensor src_tensor;
 
   bool caught = false;
   try {
     src_tensor.data<double>();
-  } catch (paddle::platform::EnforceNotMet err) {
+  } catch (platform::EnforceNotMet err) {
     caught = true;
     std::string msg =
         "holder_ should not be null\nTensor holds no memory. Call "
@@ -50,61 +51,65 @@ TEST(Tensor, DataAssert) {
    because Memory::Alloc() and Memory::Free() have not been ready.
 */
 TEST(Tensor, MutableData) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
   {
-    Tensor src_tensor;
+    framework::Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
-    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+                                        platform::CPUPlace());
     EXPECT_NE(p1, nullptr);
     // set src_tensor a new dim with large size
     // momery is supposed to be re-allocated
-    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+                                        platform::CPUPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
     // set src_tensor a new dim with same size
     // momery block is supposed to be unchanged
-    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+                                        platform::CPUPlace());
     EXPECT_EQ(p1, p2);
     // set src_tensor a new dim with smaller size
     // momery block is supposed to be unchanged
-    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+                                        platform::CPUPlace());
     EXPECT_EQ(p1, p2);
   }
 
 #ifdef PADDLE_WITH_CUDA
   {
-    Tensor src_tensor;
+    framework::Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
-    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CUDAPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+                                        platform::CUDAPlace());
     EXPECT_NE(p1, nullptr);
     // set src_tensor a new dim with large size
     // momery is supposed to be re-allocated
-    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CUDAPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+                                        platform::CUDAPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
     // set src_tensor a new dim with same size
     // momery block is supposed to be unchanged
-    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CUDAPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+                                        platform::CUDAPlace());
     EXPECT_EQ(p1, p2);
     // set src_tensor a new dim with smaller size
     // momery block is supposed to be unchanged
-    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CUDAPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+                                        platform::CUDAPlace());
     EXPECT_EQ(p1, p2);
   }
 #endif
 }
 
 TEST(Tensor, ShareDataWith) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
   {
-    Tensor src_tensor;
-    Tensor dst_tensor;
+    framework::Tensor src_tensor;
+    framework::Tensor dst_tensor;
     // Try to share data form uninitialized tensor
     bool caught = false;
     try {
@@ -121,16 +126,18 @@ TEST(Tensor, ShareDataWith) {
     }
     ASSERT_TRUE(caught);
 
-    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
+    src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+                                 platform::CPUPlace());
     dst_tensor.ShareDataWith(src_tensor);
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
   }
 
 #ifdef PADDLE_WITH_CUDA
   {
-    Tensor src_tensor;
-    Tensor dst_tensor;
-    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CUDAPlace());
+    framework::Tensor src_tensor;
+    framework::Tensor dst_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+                                 platform::CUDAPlace());
     dst_tensor.ShareDataWith(src_tensor);
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
   }
@@ -138,13 +145,12 @@ TEST(Tensor, ShareDataWith) {
 }
 
 TEST(Tensor, Slice) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
   {
-    Tensor src_tensor;
-    src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
-    Tensor slice_tensor = src_tensor.Slice(1, 3);
-    DDim slice_dims = slice_tensor.dims();
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({5, 3, 4}),
+                                 platform::CPUPlace());
+    framework::Tensor slice_tensor = src_tensor.Slice(1, 3);
+    framework::DDim slice_dims = slice_tensor.dims();
     ASSERT_EQ(arity(slice_dims), 3);
     EXPECT_EQ(slice_dims[0], 2);
     EXPECT_EQ(slice_dims[1], 3);
@@ -153,11 +159,12 @@ TEST(Tensor, Slice) {
     uintptr_t src_data_address =
         reinterpret_cast<uintptr_t>(src_tensor.data<int>());
     uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
-        src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
+        src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
     uintptr_t slice_data_address =
         reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
-    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
-        slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
+    uintptr_t slice_mutable_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<int>(
+            slice_tensor.dims(), platform::CPUPlace()));
     EXPECT_EQ(src_data_address, src_mutable_data_address);
     EXPECT_EQ(slice_data_address, slice_mutable_data_address);
     EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
@@ -165,22 +172,25 @@ TEST(Tensor, Slice) {
 
 #ifdef PADDLE_WITH_CUDA
   {
-    Tensor src_tensor;
-    src_tensor.mutable_data<double>(make_ddim({6, 9}), CUDAPlace());
-    Tensor slice_tensor = src_tensor.Slice(2, 6);
-    DDim slice_dims = slice_tensor.dims();
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
+                                    platform::CUDAPlace());
+    framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
+    framework::DDim slice_dims = slice_tensor.dims();
     ASSERT_EQ(arity(slice_dims), 2);
     EXPECT_EQ(slice_dims[0], 4);
     EXPECT_EQ(slice_dims[1], 9);
 
     uintptr_t src_data_address =
         reinterpret_cast<uintptr_t>(src_tensor.data<double>());
-    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
-        src_tensor.mutable_data<double>(src_tensor.dims(), CUDAPlace()));
+    uintptr_t src_mutable_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+            src_tensor.dims(), platform::CUDAPlace()));
     uintptr_t slice_data_address =
         reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
-    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
-        slice_tensor.mutable_data<double>(slice_tensor.dims(), CUDAPlace()));
+    uintptr_t slice_mutable_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
+            slice_tensor.dims(), platform::CUDAPlace()));
     EXPECT_EQ(src_data_address, src_mutable_data_address);
     EXPECT_EQ(slice_data_address, slice_mutable_data_address);
     EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
@@ -189,23 +199,19 @@ TEST(Tensor, Slice) {
 }
 
 TEST(Tensor, ReshapeToMatrix) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
-  Tensor src;
-  int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+  framework::Tensor src;
+  int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, platform::CPUPlace());
   for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
     src_ptr[i] = i;
   }
-  Tensor res = ReshapeToMatrix(src, 2);
+  framework::Tensor res = framework::ReshapeToMatrix(src, 2);
   ASSERT_EQ(res.dims()[0], 2 * 3);
   ASSERT_EQ(res.dims()[1], 4 * 9);
 }
 
 TEST(Tensor, Layout) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
-  Tensor src;
-  ASSERT_EQ(src.layout(), DataLayout::kNHWC);
-  src.set_layout(DataLayout::kAnyLayout);
-  ASSERT_EQ(src.layout(), DataLayout::kAnyLayout);
+  framework::Tensor src;
+  ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC);
+  src.set_layout(framework::DataLayout::kAnyLayout);
+  ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
 }
diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc
index 293c65a065402adb00571e0651ed1091384789ba..7efc649d0bcda67c663d148e83bcbb6789b0f371 100644
--- a/paddle/framework/tensor_util.cc
+++ b/paddle/framework/tensor_util.cc
@@ -31,6 +31,7 @@ struct AnyDTypeVisitor {
   void operator()() const {
     auto t = EigenVector<T>::Flatten(tensor_);
     auto o = EigenScalar<bool>::From(*out_);
+    // return any of predicate_(t) is true.
     o.device(*ctx_.eigen_device()) = predicate_(t).any();
   }
 };
@@ -66,9 +67,10 @@ struct AnyVisitor : public boost::static_visitor<bool> {
     framework::Tensor tmp;
     tmp.Resize({1});
     tmp.mutable_data<bool>(cpu);
-    platform::DeviceContextPool::Instance().Get(gpu)->Wait();
-    CopyFrom(out, cpu, &tmp);
-    platform::DeviceContextPool::Instance().Get(gpu)->Wait();
+    auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
+    gpuctx->Wait();
+    CopyFrom(out, cpu, *gpuctx, &tmp);
+    gpuctx->Wait();
     return GetResult(tmp, cpu);
   }
 
@@ -89,6 +91,7 @@ struct HasNANPredicate {
   template <typename T>
   auto operator()(const T& eigen_vec) const
       -> decltype(std::declval<T>().isnan()) {
+    // Cast eigen_vector to vector of bool. true if is inf.
     return eigen_vec.isnan();
   }
 };
@@ -102,6 +105,7 @@ struct HasInfPredicate {
   template <typename T>
   auto operator()(const T& eigen_vec) const
       -> decltype(std::declval<T>().isinf()) {
+    // Cast eigen_vector to vector of bool. true if is inf.
     return eigen_vec.isinf();
   }
 };
diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h
index e71d8e5672a3323ce789afa2d575517dc596b12c..6a21f8db1e3966fd23eee0da2346b2d61f9321fb 100644
--- a/paddle/framework/tensor_util.h
+++ b/paddle/framework/tensor_util.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/data_type.h"
 #include "paddle/framework/eigen.h"
+#include "paddle/framework/framework.pb.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 
@@ -208,8 +209,109 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
                src_ptr, size);
 }
 
-extern bool HasNAN(const framework::Tensor& tensor);
-extern bool HasInf(const framework::Tensor& tensor);
+// Returns true if a tensor contains NAN, i.e., Not A Number.
+bool HasNAN(const framework::Tensor& tensor);
+
+// Returns true if a tensor contains Inf, i.e., Infinity.
+bool HasInf(const framework::Tensor& tensor);
+
+inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
+                              const platform::DeviceContext& dev_ctx) {
+  // TODO(typhoonzero): serialize to ostream
+  {  // the 1st field, uint32_t version
+    constexpr uint32_t version = 0;
+    os.write(reinterpret_cast<const char*>(&version), sizeof(version));
+  }
+  {  // the 2nd field, tensor description
+     // int32_t  size
+     // void*    protobuf message
+    proto::TensorDesc desc;
+    desc.set_data_type(framework::ToDataType(tensor.type()));
+    auto dims = framework::vectorize(tensor.dims());
+    auto* pb_dims = desc.mutable_dims();
+    pb_dims->Resize(static_cast<int>(dims.size()), 0);
+    std::copy(dims.begin(), dims.end(), pb_dims->begin());
+    int32_t size = desc.ByteSize();
+    os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+    auto out = desc.SerializeAsString();
+    os.write(out.data(), size);
+  }
+  {  // the 3rd field, tensor data
+    uint64_t size = tensor.memory_size();
+    auto* data_ptr = tensor.data<void>();
+    PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
+                   "Index overflow when writing tensor");
+    if (platform::is_gpu_place(tensor.place())) {
+#ifdef PADDLE_WITH_CUDA
+      constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB
+      std::unique_ptr<char[]> buf(new char[kBufSize]);
+      auto& gpu_dev_ctx =
+          static_cast<const platform::CUDADeviceContext&>(dev_ctx);
+      platform::CPUPlace cpu;
+      uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
+      while (size != 0) {
+        size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
+        memory::Copy(cpu, buf.get(),
+                     boost::get<platform::CUDAPlace>(tensor.place()),
+                     reinterpret_cast<const void*>(data), size_to_write,
+                     gpu_dev_ctx.stream());
+        gpu_dev_ctx.Wait();
+        os.write(buf.get(), size_to_write);
+        data += size_to_write;
+        size -= size_to_write;
+      }
+#else
+      PADDLE_THROW("Unexpected branch");
+#endif
+    } else {
+      os.write(static_cast<const char*>(data_ptr),
+               static_cast<std::streamsize>(size));
+    }
+  }
+}
+
+inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
+  uint32_t version;
+  is.read(reinterpret_cast<char*>(&version), sizeof(version));
+  PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+  proto::TensorDesc desc;
+  {  // int32_t size
+     // proto buffer
+    int32_t size;
+    is.read(reinterpret_cast<char*>(&size), sizeof(size));
+    std::unique_ptr<char[]> buf(new char[size]);
+    is.read(reinterpret_cast<char*>(buf.get()), size);
+    PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
+                   "Cannot parse tensor desc");
+  }
+  {  // read tensor
+    std::vector<int64_t> dims;
+    dims.reserve(static_cast<size_t>(desc.dims().size()));
+    std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
+    tensor->Resize(framework::make_ddim(dims));
+
+    void* buf;
+    platform::Place cpu = platform::CPUPlace();
+    // TODO(Yancey1989): use VisiterDataType instead of DataType switch
+    switch (desc.data_type()) {
+      case proto::FP32:
+        buf = tensor->mutable_data<float>(cpu);
+        break;
+      case proto::FP64:
+        buf = tensor->mutable_data<double>(cpu);
+        break;
+      case proto::INT32:
+        buf = tensor->mutable_data<int>(cpu);
+        break;
+      case proto::INT64:
+        buf = tensor->mutable_data<int64_t>(cpu);
+        break;
+      default:
+        PADDLE_THROW("DataType %d not supported", desc.data_type());
+    }
+    is.read(static_cast<char*>(buf), tensor->memory_size());
+  }
+}
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc
index 01dfd4deb9d126b9e28391b3643156cd1b0dfc9d..0dc5166fcabf77b48b8681ab1f050e2bc88f44ab 100644
--- a/paddle/framework/tensor_util_test.cc
+++ b/paddle/framework/tensor_util_test.cc
@@ -231,7 +231,7 @@ TEST(CopyToVector, Tensor) {
 #endif
 }
 
-TEST(IsNAN, CPU) {
+TEST(HasNAN, CPU) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   Tensor src;
@@ -243,7 +243,7 @@ TEST(IsNAN, CPU) {
   ASSERT_TRUE(HasNAN(src));
 }
 
-TEST(IsInf, CPU) {
+TEST(HasInf, CPU) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   Tensor src;
@@ -254,5 +254,55 @@ TEST(IsInf, CPU) {
   ASSERT_TRUE(HasInf(src));
 }
 
+TEST(Tensor, SerializeAndDeserialize) {
+  framework::Tensor src_tensor;
+  int array[6] = {1, 2, 3, 4, 5, 6};
+  src_tensor.Resize({2, 3});
+  int* src_ptr = src_tensor.mutable_data<int>(platform::CPUPlace());
+  for (int i = 0; i < 6; ++i) {
+    src_ptr[i] = array[i];
+  }
+  {
+    framework::Tensor dst_tensor;
+    auto place = new platform::CPUPlace();
+    platform::CPUDeviceContext cpu_ctx(*place);
+    std::ostringstream oss;
+    SerializeToStream(oss, src_tensor, cpu_ctx);
+
+    std::istringstream iss(oss.str());
+    DeserializeFromStream(iss, &dst_tensor);
+    int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
+    for (int i = 0; i < 5; ++i) {
+      ASSERT_EQ(dst_ptr[i], array[i]);
+    }
+    delete place;
+  }
+#ifdef PADDLE_WITH_CUDA
+  {
+    Tensor gpu_tensor;
+    gpu_tensor.Resize({2, 3});
+    Tensor dst_tensor;
+
+    auto gpu_place = new platform::CUDAPlace();
+    platform::CUDADeviceContext gpu_ctx(*gpu_place);
+
+    CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
+
+    std::ostringstream oss;
+    SerializeToStream(oss, gpu_tensor, gpu_ctx);
+
+    std::istringstream iss(oss.str());
+    DeserializeFromStream(iss, &dst_tensor);
+
+    int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
+    for (int i = 0; i < 6; ++i) {
+      ASSERT_EQ(dst_ptr[i], array[i]);
+    }
+
+    delete gpu_place;
+  }
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/tensor_util_test.cu b/paddle/framework/tensor_util_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebd35fdf6c2a1388fec23057070f723c8ef9da9c
--- /dev/null
+++ b/paddle/framework/tensor_util_test.cu
@@ -0,0 +1,57 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "gtest/gtest.h"
+#include "paddle/framework/tensor_util.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace framework {
+
+static __global__ void FillNAN(float* buf) {
+  buf[0] = 0.0;
+  buf[1] = 0.1;
+  buf[2] = NAN;
+}
+static __global__ void FillInf(float* buf) {
+  buf[0] = 0.0;
+  buf[1] = INFINITY;
+  buf[2] = 0.5;
+}
+
+TEST(HasNAN, GPU) {
+  Tensor tensor;
+  platform::CUDAPlace gpu(0);
+  auto& pool = platform::DeviceContextPool::Instance();
+  auto* cuda_ctx = pool.GetByPlace(gpu);
+  float* buf = tensor.mutable_data<float>({3}, gpu);
+  FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+  cuda_ctx->Wait();
+  ASSERT_TRUE(HasNAN(tensor));
+}
+
+TEST(HasInf, GPU) {
+  Tensor tensor;
+  platform::CUDAPlace gpu(0);
+  auto& pool = platform::DeviceContextPool::Instance();
+  auto* cuda_ctx = pool.GetByPlace(gpu);
+  float* buf = tensor.mutable_data<float>({3}, gpu);
+  FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+  cuda_ctx->Wait();
+  ASSERT_TRUE(HasInf(tensor));
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h
index 5f6b2d458f7ee764c22d203f285b78023b6012f3..bcd8190755083ec30687675602a1c95a9c15c69e 100644
--- a/paddle/framework/threadpool.h
+++ b/paddle/framework/threadpool.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <condition_variable>
 #include <functional>
+#include <future>
 #include <mutex>
 #include <queue>
 #include <thread>
@@ -25,10 +26,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-typedef std::function<void()> Task;
-
 class ThreadPool {
  public:
+  typedef std::packaged_task<void()> Task;
+  typedef std::function<void()> Fun;
+
   /**
    * @brief   Get a instance of threadpool, the thread number will
    *          be specified as the number of hardware thread contexts
@@ -61,13 +63,18 @@ class ThreadPool {
   /**
    * @brief   Push a function to the queue, and will be scheduled and
    *          executed if a thread is available.
-   * @param[in] Task  will be pushed to the task queue.
+   * @param[in] Task, will be pushed to the task queue.
+   * @return    std::future<void>, we could wait for the task finished by
+   *            f.wait().
    */
-  void Run(const Task& fn) {
+  std::future<void> Run(const Fun& fn) {
     std::unique_lock<std::mutex> lock(mutex_);
-    tasks_.push(fn);
+    Task task(std::bind(fn));
+    std::future<void> f = task.get_future();
+    tasks_.push(std::move(task));
     lock.unlock();
     scheduled_.notify_one();
+    return f;
   }
 
   /**
@@ -110,7 +117,7 @@ class ThreadPool {
         break;
       }
       // pop a task from the task queue
-      auto task = tasks_.front();
+      auto task = std::move(tasks_.front());
       tasks_.pop();
 
       --available_;
diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc
index 012d92a5edc415f0bb2f8a0ea38ffeb9549d54fa..50b6238cd8786be9d8cf2d5f821daadea12bd208 100644
--- a/paddle/framework/threadpool_test.cc
+++ b/paddle/framework/threadpool_test.cc
@@ -20,16 +20,21 @@ limitations under the License. */
 namespace framework = paddle::framework;
 
 void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
+  std::vector<std::future<void>> fs;
   for (int i = 0; i < cnt; ++i) {
-    pool->Run([&sum]() { sum.fetch_add(1); });
+    auto f = pool->Run([&sum]() { sum.fetch_add(1); });
+    fs.push_back(std::move(f));
+  }
+  for (auto& f : fs) {
+    f.wait();
   }
 }
 
 TEST(ThreadPool, ConcurrentInit) {
   framework::ThreadPool* pool;
-  int concurrent_cnt = 50;
+  int n = 50;
   std::vector<std::thread> threads;
-  for (int i = 0; i < concurrent_cnt; ++i) {
+  for (int i = 0; i < n; ++i) {
     std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
     threads.push_back(std::move(t));
   }
@@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) {
   }
 }
 
-TEST(ThreadPool, ConcurrentStart) {
+TEST(ThreadPool, ConcurrentRun) {
   framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
   std::atomic<int> sum(0);
   std::vector<std::thread> threads;
-  int concurrent_cnt = 50;
+  int n = 50;
   // sum = (n * (n + 1)) / 2
-  for (int i = 1; i <= concurrent_cnt; ++i) {
+  for (int i = 1; i <= n; ++i) {
     std::thread t(do_sum, pool, std::ref(sum), i);
     threads.push_back(std::move(t));
   }
@@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) {
     t.join();
   }
   pool->Wait();
-  EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
+  EXPECT_EQ(sum, ((n + 1) * n) / 2);
 }
diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index ab52a41b539236f1691ce8bc02d31e336ee4ccbb..e65a5dce52c3c51d3d6bee1684c1e97230203d38 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -31,8 +31,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   int groups = ctx->Attrs().Get<int>("groups");
   std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
-  int input_channels = in_dims[1];
-  int output_channels = filter_dims[0];
 
   PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
                  "Conv intput should be 4-D or 5-D tensor.");
@@ -45,9 +43,13 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_EQ(
       paddings.size(), strides.size(),
       "Conv paddings dimension and Conv strides dimension should be the same.");
+
+  int input_channels = in_dims[1];
   PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
                     "The number of input channels should be equal to filter "
                     "channels * groups.");
+
+  int output_channels = filter_dims[0];
   PADDLE_ENFORCE_EQ(
       output_channels % groups, 0,
       "The number of output channels should be divided by groups.");
diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc
index 65f021d91931541b712bd46aebc06e68144b2af0..08b972a233aab8596a5ce7f74ea903df3b8ef0f2 100644
--- a/paddle/operators/load_op.cc
+++ b/paddle/operators/load_op.cc
@@ -38,7 +38,7 @@ class LoadOp : public framework::OperatorBase {
                    out_var_name);
 
     auto *tensor = out_var->GetMutable<framework::LoDTensor>();
-    framework::DeserializeFromStream(fin, tensor);
+    DeserializeFromStream(fin, tensor);
 
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index bf47879f772a3013bd7ce78c6f8a6aefe65298f9..b97faec4ed687c1cf8d746cdf615e86fd79ca921 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -9,9 +9,9 @@ if(WITH_GPU)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
     nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
-    nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
+    nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
-    nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
+    nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor)
     nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
     nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
     nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
@@ -23,9 +23,9 @@ else()
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context)
     cc_library(pooling SRCS pooling.cc DEPS device_context)
     cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
-    cc_library(vol2col SRCS vol2col.cc DEPS device_context)
+    cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor)
     cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
-    cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
+    cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
     cc_library(maxouting SRCS maxouting.cc DEPS device_context)
     cc_library(unpooling SRCS unpooling.cc DEPS device_context)
diff --git a/paddle/platform/for_range.h b/paddle/platform/for_range.h
index 5427aa28238d6b46eb72d1fb49303dce3d871d7d..694a66d9ac4eb6ad02daf1931806fa1287de7cab 100644
--- a/paddle/platform/for_range.h
+++ b/paddle/platform/for_range.h
@@ -62,7 +62,7 @@ struct ForRange<CUDADeviceContext> {
 
   template <typename Function>
   inline void operator()(Function func) const {
-    constexpr size_t num_threads = 1024;
+    constexpr int num_threads = 1024;
     int block_size = limit_ <= num_threads ? limit_ : num_threads;
     int grid_size = (limit_ + num_threads - 1) / num_threads;
 
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 6afed7eec7001b646d55cef0bc3f59782b80b15f..ced75cbfd899980390d41610d863d6cf154570b0 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -3,6 +3,7 @@ if(WITH_PYTHON)
     SRCS pybind.cc exception.cc protobuf.cc const_value.cc
     DEPS pybind python backward proto_desc paddle_memory executor prune init
     ${GLOB_OP_LIB})
+  target_link_libraries(paddle_pybind rt)
 endif(WITH_PYTHON)
 
 if(WITH_DOC)
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index e43b9c218a3ecb9e7f20fb7e8b14a85a29947eef..92039ec6b05d224e702f0ba5dc05c057a492287e 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -178,7 +178,7 @@ EOF
     # run paddle version to install python packages first
     RUN apt-get update &&\
         ${NCCL_DEPS}\
-        apt-get install -y wget python-pip dmidecode && pip install -U pip && \
+        apt-get install -y wget python-pip dmidecode python-tk && pip install -U pip && \
         pip install /*.whl; apt-get install -f -y && \
         apt-get clean -y && \
         rm -f /*.whl && \
diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py
index 634388094c804827657dc83d5c205e680625b156..7bdddeaabec733ef26b3f766c6437f5c53d65044 100644
--- a/python/paddle/v2/dataset/flowers.py
+++ b/python/paddle/v2/dataset/flowers.py
@@ -44,7 +44,7 @@ __all__ = ['train', 'test', 'valid']
 DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
 LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
 SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
-DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
+DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118'
 LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
 SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
 # In official 'readme', tstid is the flag of test data
diff --git a/python/paddle/v2/fluid/data_feeder.py b/python/paddle/v2/fluid/data_feeder.py
index 30a542af212926c93381aade426e25f2117e4662..24036c3e75b9594ba58cccb02825ab8020d1e107 100644
--- a/python/paddle/v2/fluid/data_feeder.py
+++ b/python/paddle/v2/fluid/data_feeder.py
@@ -3,7 +3,7 @@ import core
 import numpy
 import six.moves as six
 
-from framework import Variable
+from framework import Variable, default_main_program
 
 __all__ = ['DataFeeder']
 
@@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object):
 
 
 class DataFeeder(object):
-    def __init__(self, feed_list, place):
+    def __init__(self, feed_list, place, program=None):
         self.feed_dtypes = []
         self.feed_names = []
         self.feed_shapes = []
         self.feed_lod_level = []
+        if program is None:
+            program = default_main_program()
         for each_var in feed_list:
+            if isinstance(each_var, basestring):
+                each_var = program.block(0).var(each_var)
             if not isinstance(each_var, Variable):
                 raise TypeError("Feed list should contain a list of variable")
             self.feed_dtypes.append(each_var.dtype)
diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py
index 69a732fc45a1946f260cdd9a9c2da150b87c3ddd..c47ce82aba7fa5ac42ac26cd25fa3ebc93e96cb2 100644
--- a/python/paddle/v2/fluid/io.py
+++ b/python/paddle/v2/fluid/io.py
@@ -188,7 +188,7 @@ def save_inference_model(dirname,
             raise ValueError("'feed_var_names' should be a list of str.")
 
     if isinstance(target_vars, Variable):
-        feeded_var_names = [feeded_var_names]
+        target_vars = [target_vars]
     else:
         if not (bool(target_vars) and all(
                 isinstance(var, Variable) for var in target_vars)):