diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
index 0f78eeab9bc643a1a70c4b6ab02a160bbeda2b33..2acb96d1b4f5903ff6c57b10e7621c8adaf73171 100644
--- a/paddle/operators/sgd_op.cc
+++ b/paddle/operators/sgd_op.cc
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(Param) of SGDOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                       "Learning rate should have 1 element");
     auto param_dim = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
-                      "Two input of SGD Op's dimension must be same.");
+    // TODO(qijun): check that the dimensions of Param and Grad match
+    // at compile time and at run time.
     ctx->SetOutputDim("ParamOut", param_dim);
   }
 };
 
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "Input parameter");
     AddInput("LearningRate", "Learning rate of SGD");
@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
 )DOC");
   }
 };
+
+template <typename T>
+struct SparseSGDFunctor<platform::CPUPlace, T> {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  const framework::Tensor& learning_rate,
+                  framework::Tensor* output) {
+    auto in_height = input.height();
+    auto out_dims = output->dims();
+    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+    auto& in_value = input.value();
+    auto& in_rows = input.rows();
+
+    int64_t in_row_numel = in_value.numel() / in_rows.size();
+    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
+
+    auto* in_data = in_value.data<T>();
+    auto* out_data = output->data<T>();
+    auto* lr = learning_rate.data<T>();
+
+    for (size_t i = 0; i < in_rows.size(); i++) {
+      for (int64_t j = 0; j < in_row_numel; j++) {
+        out_data[in_rows[i] * in_row_numel + j] -=
+            lr[0] * in_data[i * in_row_numel + j];
+      }
+    }
+  }
+};
+
+template struct SparseSGDFunctor<platform::CPUPlace, float>;
+
 }  // namespace operators
 }  // namespace paddle
 
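The row-wise update performed by the CPU functor above can be restated in plain numpy. A minimal sketch, using illustrative names (`param`, `grad_rows`, `grad_value`, `lr`) rather than Paddle APIs:

```python
import numpy as np

# Sketch of the CPU SparseSGDFunctor update: for each selected row i,
# subtract lr * grad_value[i] from the matching row of the parameter.
height, row_numel = 10, 12
param = np.full((height, row_numel), 5.0, dtype=np.float32)
grad_rows = [0, 4, 7]                            # SelectedRows::rows()
grad_value = np.ones((len(grad_rows), row_numel), dtype=np.float32)
lr = np.float32(2.0)                             # one-element LearningRate

for i, row in enumerate(grad_rows):              # mirrors the nested loop above
    param[row] -= lr * grad_value[i]

assert param[0, 0] == 3.0 and param[1, 0] == 5.0  # untouched rows keep 5.0
```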
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
index f5ba6d3c29f8dfbfdea4fbf2c3d5fd7f5b358666..106f9b746ba6614d8fa68b677c47ec04ed26fb81 100644
--- a/paddle/operators/sgd_op.cu
+++ b/paddle/operators/sgd_op.cu
@@ -14,6 +14,66 @@
 
 #define EIGEN_USE_GPU
 #include "paddle/operators/sgd_op.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+
+namespace {
+template <typename T>
+__global__ void SparseSGDFunctorKernel(const T* selected_rows,
+                                       const int64_t* rows,
+                                       const T* learning_rate, T* tensor_out,
+                                       int64_t row_numel, int block_size) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  selected_rows += ty * row_numel;
+  tensor_out += rows[ty] * row_numel;
+
+  for (int index = tid; index < row_numel; index += block_size) {
+    // Since indices in rows of SelectedRows can be duplicated, we have to
+    // use an atomic operation to avoid concurrent write errors.
+    paddle::platform::CudaAtomicAdd(
+        tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
+  }
+}
+}  // namespace
+
+template <typename T>
+struct SparseSGDFunctor<platform::GPUPlace, T> {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  const framework::Tensor& learning_rate,
+                  framework::Tensor* output) {
+    auto in_height = input.height();
+    auto out_dims = output->dims();
+    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+    auto& in_value = input.value();
+    auto& in_rows = input.rows();
+
+    int64_t in_row_numel = in_value.numel() / in_rows.size();
+    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
+
+    auto* in_data = in_value.data<T>();
+    auto* out_data = output->data<T>();
+
+    int block_size = 256;
+    dim3 threads(block_size, 1);
+    dim3 grid(1, in_rows.size());
+    SparseSGDFunctorKernel<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
+                              out_data, in_row_numel, block_size);
+  }
+};
+
+template struct SparseSGDFunctor<platform::GPUPlace, float>;
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(sgd,
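
The kernel uses `CudaAtomicAdd` because the same index can appear more than once in `rows`, so threads updating the same output row must accumulate rather than overwrite each other. The same hazard exists in a plain numpy scatter; a minimal sketch with illustrative names:

```python
import numpy as np

# Why the kernel needs atomics: with a duplicated row index, a plain scatter
# keeps only the last write, while np.subtract.at accumulates every update,
# matching the per-element atomic subtract in the kernel above.
param = np.full((4, 2), 5.0, dtype=np.float32)
rows = np.array([1, 1])                      # duplicated row index
grad = np.ones((2, 2), dtype=np.float32)
lr = np.float32(2.0)

wrong = param.copy()
wrong[rows] -= lr * grad                     # last write wins: row 1 -> 3.0

right = param.copy()
np.subtract.at(right, rows, lr * grad)       # accumulates: row 1 -> 1.0
```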
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index 26f4012f258771794c736dbfad4af174b017f410..78b595fc6c63d775b627f23cafa9458f1dadd4e5 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -15,31 +15,53 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/selected_rows.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename Place, typename T>
+struct SparseSGDFunctor {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  const framework::Tensor& learning_rate,
+                  framework::Tensor* output);
+};
+
 template <typename Place, typename T>
 class SGDOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto grad = ctx.Input<framework::Tensor>("Grad");
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+    auto* param = ctx.Input<framework::Tensor>("Param");
+    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
 
-    param_out->mutable_data<T>(ctx.GetPlace());
+    auto* grad_var = ctx.InputVar("Grad");
+    // Actually, every tensor is a LoDTensor except for SelectedRows.
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      param_out->mutable_data<T>(ctx.GetPlace());
+      auto* grad = ctx.Input<framework::Tensor>("Grad");
 
-    auto p = framework::EigenVector<T>::Flatten(*param);
-    auto g = framework::EigenVector<T>::Flatten(*grad);
-    auto o = framework::EigenVector<T>::Flatten(*param_out);
-    auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
-    auto place = ctx.GetEigenDevice<Place>();
+      auto p = framework::EigenVector<T>::Flatten(*param);
+      auto g = framework::EigenVector<T>::Flatten(*grad);
+      auto o = framework::EigenVector<T>::Flatten(*param_out);
+      auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+      auto place = ctx.GetEigenDevice<Place>();
 
-    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
-    o.device(place) = p - lr.broadcast(grad_dsize) * g;
+      Eigen::DSizes<int, 1> grad_dsize(grad->numel());
+      o.device(place) = p - lr.broadcast(grad_dsize) * g;
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      // TODO(qijun): In the sparse SGD operator, the update is enforced to
+      // be in-place. This manual optimization makes it difficult to track
+      // data dependencies. It would be better to find a more elegant solution.
+      PADDLE_ENFORCE_EQ(param, param_out);
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+      SparseSGDFunctor<Place, T> functor;
+      functor(ctx.device_context(), *grad, *learning_rate, param_out);
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
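The dense branch keeps the original element-wise rule, `param_out = param - learning_rate * grad`, broadcasting the one-element learning rate over the gradient. A short numpy check of that rule, with illustrative names:

```python
import numpy as np

# Dense branch of SGDOpKernel: the one-element learning rate is broadcast
# over the gradient, mirroring lr.broadcast(grad_dsize) in the Eigen code.
param = np.array([1.0, 2.0, 3.0], dtype=np.float32)
grad = np.array([0.5, 0.5, 0.5], dtype=np.float32)
lr = np.array([0.1], dtype=np.float32)       # one-element LearningRate tensor

param_out = param - lr * grad
np.testing.assert_allclose(param_out, [0.95, 1.95, 2.95])
```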
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 9eb1bf4a16ef40bb3044f46db9777fd2f6c341d2..16661b93e56da30ecd3848d28a0f4667b710e80c 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -154,7 +154,15 @@ PYBIND11_PLUGIN(core) {
            py::return_value_policy::reference)
       .def("set_height", &SelectedRows::set_height)
       .def("height", &SelectedRows::height)
-      .def("set_rows", &SelectedRows::set_rows)
+      .def("set_rows",
+           [](SelectedRows &self, std::vector<int64_t> rows) {
+#ifndef PADDLE_WITH_CUDA
+             self.set_rows(rows);
+#else
+             Vector<int64_t> new_rows(rows);
+             self.set_rows(new_rows);
+#endif
+           })
       .def("rows", [](SelectedRows &self) {
 #ifndef PADDLE_WITH_CUDA
         return self.rows();
@@ -187,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle.
              return self.GetMutable<LoDTensor>();
            },
            py::return_value_policy::reference)
+      .def("get_selected_rows",
+           [](Variable &self) -> SelectedRows * {
+             return self.GetMutable<SelectedRows>();
+           },
+           py::return_value_policy::reference)
       .def("get_net",
            [](Variable &self) -> operators::NetOp * {
              return self.GetMutable<operators::NetOp>();
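With `get_selected_rows` and the wrapped `set_rows` exposed, a SelectedRows gradient can be assembled from Python before running the operator. A short usage sketch, mirroring the test below:

```python
import paddle.v2.framework.core as core

# Assemble a SelectedRows variable through the new bindings: fetch it from a
# Variable, then set its height and row indices before filling the tensor.
scope = core.Scope()
grad = scope.var('Grad').get_selected_rows()
grad.set_height(10)
grad.set_rows([0, 4, 7])
assert grad.height() == 10 and grad.rows()[1] == 4
```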
diff --git a/python/paddle/v2/framework/tests/test_selected_rows.py b/python/paddle/v2/framework/tests/test_selected_rows.py
index 661e81817951f5605ba3ca7fb0cc667074b1e37c..e8a930cb08c42b48f678bdd7bdb7698923535d4f 100644
--- a/python/paddle/v2/framework/tests/test_selected_rows.py
+++ b/python/paddle/v2/framework/tests/test_selected_rows.py
@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase):
         place = core.CPUPlace()
         height = 10
         rows = [0, 4, 7]
-        row_numel = 10
-        selcted_rows = core.SelectedRows(rows, row_numel)
-        np_array = np.ones((len(rows), height)).astype("float32")
+        row_numel = 12
+        selected_rows = core.SelectedRows(rows, height)
+        np_array = np.ones((len(rows), row_numel)).astype("float32")
         np_array[0, 0] = 2.0
         np_array[2, 8] = 4.0
-        tensor = selcted_rows.get_tensor()
+        tensor = selected_rows.get_tensor()
         tensor.set(np_array, place)
 
         # compare rows
-        self.assertEqual(0, selcted_rows.rows()[0])
-        self.assertEqual(4, selcted_rows.rows()[1])
-        self.assertEqual(7, selcted_rows.rows()[2])
+        self.assertEqual(0, selected_rows.rows()[0])
+        self.assertEqual(4, selected_rows.rows()[1])
+        self.assertEqual(7, selected_rows.rows()[2])
 
         # compare height
-        self.assertEqual(10, selcted_rows.height())
+        self.assertEqual(10, selected_rows.height())
 
         # compare tensor
         self.assertAlmostEqual(2.0,
-                               selcted_rows.get_tensor().get_float_element(0))
+                               selected_rows.get_tensor().get_float_element(0))
         self.assertAlmostEqual(1.0,
-                               selcted_rows.get_tensor().get_float_element(1))
+                               selected_rows.get_tensor().get_float_element(1))
         self.assertAlmostEqual(
-            4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8))
+            4.0,
+            selected_rows.get_tensor().get_float_element(2 * row_numel + 8))
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index 2dd881e5e107249277a91bd8e3a72567269e1cd4..01262bba4d43adaed179baef88ccab6e69b0884b 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -1,5 +1,7 @@
 import unittest
 import numpy as np
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
 from op_test import OpTest
 
 
@@ -17,5 +19,70 @@ class TestSGDOp(OpTest):
         self.check_output()
 
 
+class TestSparseSGDOp(unittest.TestCase):
+    def check_with_place(self, place):
+        scope = core.Scope()
+
+        # create and initialize Grad Variable
+        height = 10
+        rows = [0, 4, 7]
+        row_numel = 12
+
+        grad_selected_rows = scope.var('Grad').get_selected_rows()
+        grad_selected_rows.set_height(height)
+        grad_selected_rows.set_rows(rows)
+        np_array = np.ones((len(rows), row_numel)).astype("float32")
+        np_array[0, 0] = 2.0
+        np_array[2, 8] = 4.0
+
+        grad_tensor = grad_selected_rows.get_tensor()
+        grad_tensor.set(np_array, place)
+
+        # create and initialize Param Variable
+        param = scope.var('Param').get_tensor()
+        param_array = np.full((height, row_numel), 5.0).astype("float32")
+        param.set(param_array, place)
+
+        # create and initialize LearningRate Variable
+        lr = scope.var('LearningRate').get_tensor()
+        lr_array = np.full((1), 2.0).astype("float32")
+        lr.set(lr_array, place)
+
+        # create and run sgd operator
+        sgd_op = Operator(
+            "sgd",
+            Param='Param',
+            Grad='Grad',
+            ParamOut='Param',
+            LearningRate='LearningRate')
+        ctx = core.DeviceContext.create(place)
+        sgd_op.run(scope, ctx)
+
+        # get and compare result
+        result_array = np.array(param)
+
+        # rows[0] = 0, 5.0 - 2.0 * 2.0
+        self.assertAlmostEqual(1.0, result_array[rows[0], 0])
+        # rows[0] = 0, 5.0 - 2.0 * 1.0
+        self.assertAlmostEqual(3.0, result_array[rows[0], 2])
+        # 5.0 - 2.0 * 0.0
+        self.assertAlmostEqual(5.0, result_array[1, 0])
+        # rows[1] = 4, 5.0 - 2.0 * 1.0
+        self.assertAlmostEqual(3.0, result_array[rows[1], 10])
+        # 5.0 - 2.0 * 0.0
+        self.assertAlmostEqual(5.0, result_array[5, 8])
+        # rows[2] = 7, 5.0 - 2.0 * 1.0
+        self.assertAlmostEqual(3.0, result_array[rows[2], 1])
+        # rows[2] = 7, 5.0 - 2.0 * 4.0
+        self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
+
+    def test_sparse_sgd(self):
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu():
+            places.append(core.GPUPlace(0))
+        for place in places:
+            self.check_with_place(place)
+
+
 if __name__ == "__main__":
     unittest.main()
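
The hand-computed expectations in `TestSparseSGDOp` can be cross-checked with a numpy scatter. A minimal sketch reusing the test's constants:

```python
import numpy as np

# Recompute TestSparseSGDOp's expected values: start from a parameter full
# of 5.0 and scatter-subtract lr * grad into rows [0, 4, 7].
height, row_numel, rows = 10, 12, [0, 4, 7]
lr = np.float32(2.0)
grad = np.ones((len(rows), row_numel), dtype=np.float32)
grad[0, 0], grad[2, 8] = 2.0, 4.0
expected = np.full((height, row_numel), 5.0, dtype=np.float32)
np.subtract.at(expected, rows, lr * grad)

assert expected[0, 0] == 1.0 and expected[0, 2] == 3.0   # 5 - 2*2, 5 - 2*1
assert expected[1, 0] == 5.0 and expected[7, 8] == -3.0  # untouched, 5 - 2*4
```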