diff --git a/paddle/fluid/eager/autograd_meta.h b/paddle/fluid/eager/autograd_meta.h index eb5a9b5850a68cdf0f7c86d5c49ae2b86e6d7ceb..7d87a7cbaafa8a3c1e4920bf49496ef8459d0806 100644 --- a/paddle/fluid/eager/autograd_meta.h +++ b/paddle/fluid/eager/autograd_meta.h @@ -123,12 +123,6 @@ class AutogradMeta : public AbstractAutogradMeta { stop_gradient_ = static_cast(stop_gradient); } - void WeakSetStopGradient(bool stop_gradient) { - if (stop_gradient_ == -1) { - stop_gradient_ = static_cast(stop_gradient); - } - } - bool Persistable() const { return persistable_; } void SetPersistable(bool persistable) { persistable_ = persistable; } diff --git a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc index 391d218362b50ff522050656b8c955a794e0be84..bcdeb01e635e5ab9ccd95d7081318bc6fc5c858d 100644 --- a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc +++ b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc @@ -163,7 +163,7 @@ TEST(EagerUtils, PassStopGradient) { auto_grad1.get(), auto_grad2.get(), auto_grad3.get()); - CHECK(auto_grad0->StopGradient() == false); + CHECK(auto_grad0->StopGradient() == true); CHECK(auto_grad1->StopGradient() == true); CHECK(auto_grad2->StopGradient() == true); CHECK(auto_grad3->StopGradient() == true); diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 783afcc1e2c73745dee60ece6a75e139e9218287..a42b118771830ce347655245dd8b2ff6aaf1f9ad 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper { VLOG(2) << "Tensor is NULL"; return; } - element->WeakSetStopGradient(stop_gradient_); + element->SetStopGradient(stop_gradient_); } bool stop_gradient_ = true; diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc index 970530353c7927854f075c84c2891189ed6b33d6..b841afff1579f3db4001d041f02a0f0c92a47443 100644 --- a/paddle/fluid/pybind/eager_py_layer.cc +++ b/paddle/fluid/pybind/eager_py_layer.cc @@ -315,7 +315,7 @@ PyObject* pylayer_method_apply(PyObject* cls, non_differentiable.end()) { outputs_autograd_meta[i][j]->SetStopGradient(true); } else { - outputs_autograd_meta[i][j]->WeakSetStopGradient(false); + outputs_autograd_meta[i][j]->SetStopGradient(false); } } }