diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc
index 105d61015fcb9d5c4a8f97ebe69dcf69a0510ab5..96a132ac6abc21bef20f0c688f72c03ffebe47bd 100644
--- a/paddle/fluid/operators/set_value_op.cc
+++ b/paddle/fluid/operators/set_value_op.cc
@@ -146,22 +146,79 @@ Assignment to a Tensor in static mode.
 )DOC");
   }
 };
+
+template <typename T>
+class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
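+    // When a ValueTensor was assigned, its gradient is the slice of the
+    // output gradient over the assigned region; otherwise the output
+    // gradient simply flows back to Input unchanged.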
+    if (this->HasInput("ValueTensor")) {
+      op->SetType("slice");
+      op->SetInput("Input", this->OutputGrad("Out"));
+      if (this->HasInput("StartsTensorList")) {
+        op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
+      }
+      if (this->HasInput("EndsTensorList")) {
+        op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
+      }
+
+      // Convert the std::vector<int64_t> attributes to std::vector<int>,
+      // since the slice op expects int attributes.
+      std::vector<int64_t> axes_int64 =
+          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes"));
+      std::vector<int64_t> starts_int64 =
+          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts"));
+      std::vector<int64_t> ends_int64 =
+          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends"));
+      std::vector<int64_t> decrease_axes_int64 = BOOST_GET_CONST(
+          std::vector<int64_t>, this->GetAttr("decrease_axes"));
+
+      std::vector<int> axes(axes_int64.begin(), axes_int64.end());
+      std::vector<int> starts(starts_int64.begin(), starts_int64.end());
+      std::vector<int> ends(ends_int64.begin(), ends_int64.end());
+      std::vector<int> decrease_axes(decrease_axes_int64.begin(),
+                                     decrease_axes_int64.end());
+
+      op->SetAttr("axes", axes);
+      op->SetAttr("starts", starts);
+      op->SetAttr("ends", ends);
+      op->SetAttr("decrease_axis", decrease_axes);
+      op->SetAttr("infer_flags", std::vector<int>({}));
+
+      op->SetOutput("Out", this->InputGrad("ValueTensor"));
+    } else {
+      op->SetType("assign");
+      op->SetInput("X", this->OutputGrad("Out"));
+      op->SetOutput("Out", this->InputGrad("Input"));
+    }
+  }
+};
+
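+// set_value writes into Input, so Input and Out may share the same buffer.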
+DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    set_value, ops::SetValue, ops::SetValueMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
+                  ops::SetValueGradMaker<paddle::framework::OpDesc>,
+                  ops::SetValueGradMaker<paddle::imperative::OpBase>,
+                  ops::SetValueOpInplaceInferer);
 
 REGISTER_OP_CPU_KERNEL(
-    set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
+    set_value, ops::SetValueKernel<plat::CPUDeviceContext, int>,
+    ops::SetValueKernel<plat::CPUDeviceContext, int64_t>,
+    ops::SetValueKernel<plat::CPUDeviceContext, float>,
+    ops::SetValueKernel<plat::CPUDeviceContext, double>,
+    ops::SetValueKernel<plat::CPUDeviceContext, bool>);
 
 REGISTER_OP_VERSION(set_value)
     .AddCheckpoint(
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 0817dc336716211e3a538195f81d60ef3e0c5d08..ace62d210b368d1a20fa5aa890f23cd78162bdf7 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -718,7 +718,10 @@ void BindImperative(py::module *m_ptr) {
                {
                  // Release gil and do tracing
                  py::gil_scoped_release release;
-                 tracer->TraceOp("set_value", ins, outs, std::move(attrs));
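+                 // Pass the inplace map {"Input": "Out"} so the traced op
+                 // reuses the input variable for its output.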
+                 tracer->TraceOp("set_value", ins, outs, std::move(attrs),
+                                 {{"Input", "Out"}});
                }
              } else {
                auto self_numpy = TensorToPyArray(*self_tensor);
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index 0885891cdbe02747c8babbdbe748f21b30c34598..9534e4fe9541663c204c10a8bfab1b0696cbaac5 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -775,5 +775,80 @@ class TestError(TestSetValueBase):
         self._broadcast_mismatch()
 
 
+# 5. Test backward
+
+
+class Model(paddle.nn.Layer):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.conv = paddle.nn.Conv2D(12, 12, 3)
+
+    def forward(self, x, y):
+        x = self.conv(x)
+        y = self.conv(y)
+        var = y.flatten()
+
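+        # In dygraph mode this slice assignment is traced as the set_value
+        # op, so loss.backward() exercises its gradient path.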
+        x[0, :, 0, 0] = var
+        loss = paddle.mean(x)
+        return loss, var, x
+
+
+class TestBackward(unittest.TestCase):
+    def test_static(self):
+        paddle.enable_static()
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+
+        x_np = np.random.random(size=(4, 4)).astype('float32')
+        y_np = np.random.random(size=(4, 4)).astype('float32')
+        label_np = np.random.randint(2, size=(4, 1)).astype('int64')
+
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
+            y = paddle.static.data(name="y", shape=[4, 4], dtype='float32')
+
+            label = paddle.static.data(
+                name="label", shape=[4, 1], dtype='int64')
+
+            z = paddle.add(x, y)
+            var = y[0, :]
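+            # Slice assignment lowers to the set_value op; its backward
+            # should route the gradient of the assigned region to `var`.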
+            z[0, :] = var
+
+            prediction = paddle.static.nn.fc(x=z, size=2, activation='softmax')
+
+            cost = paddle.nn.functional.cross_entropy(
+                input=prediction, label=label)
+            loss = paddle.mean(cost)
+            sgd = paddle.optimizer.SGD(learning_rate=0.01)
+            sgd.minimize(loss)
+
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        exe.run(startup_program)
+
+        var_grad, z_grad = exe.run(
+            main_program,
+            feed={"x": x_np,
+                  "y": y_np,
+                  "label": label_np},
+            fetch_list=[var.name + "@GRAD", z.name + "@GRAD"])
+
+        self.assertTrue((var_grad == z_grad[0, :]).all())
+
+    def test_dynamic(self):
+        paddle.disable_static()
+        model = Model()
+        x = paddle.ones([1, 12, 3, 3]).astype("float32")
+        y = paddle.ones([1, 12, 3, 3]).astype("float32")
+        loss, var, x = model(x, y)
+        loss.backward()
+
+        self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
+        self.assertTrue((var.grad == x.grad[0, :, 0, 0]).all())
+
+
 if __name__ == '__main__':
     unittest.main()