From cccba68c63b76bef6b00f3ff6f09c9634fb60527 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 17 Aug 2022 14:18:45 +0800 Subject: [PATCH] [Eager]fix_stop_gradient (#45154) * fix_stop_gradient --- paddle/fluid/eager/autograd_meta.h | 6 ------ paddle/fluid/eager/tests/task_tests/eager_utils_test.cc | 2 +- paddle/fluid/eager/utils.h | 2 +- paddle/fluid/pybind/eager_py_layer.cc | 2 +- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/eager/autograd_meta.h b/paddle/fluid/eager/autograd_meta.h index eb5a9b5850a..7d87a7cbaaf 100644 --- a/paddle/fluid/eager/autograd_meta.h +++ b/paddle/fluid/eager/autograd_meta.h @@ -123,12 +123,6 @@ class AutogradMeta : public AbstractAutogradMeta { stop_gradient_ = static_cast(stop_gradient); } - void WeakSetStopGradient(bool stop_gradient) { - if (stop_gradient_ == -1) { - stop_gradient_ = static_cast(stop_gradient); - } - } - bool Persistable() const { return persistable_; } void SetPersistable(bool persistable) { persistable_ = persistable; } diff --git a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc index 391d218362b..bcdeb01e635 100644 --- a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc +++ b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc @@ -163,7 +163,7 @@ TEST(EagerUtils, PassStopGradient) { auto_grad1.get(), auto_grad2.get(), auto_grad3.get()); - CHECK(auto_grad0->StopGradient() == false); + CHECK(auto_grad0->StopGradient() == true); CHECK(auto_grad1->StopGradient() == true); CHECK(auto_grad2->StopGradient() == true); CHECK(auto_grad3->StopGradient() == true); diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 783afcc1e2c..a42b1187718 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper { VLOG(2) << "Tensor is NULL"; return; } - element->WeakSetStopGradient(stop_gradient_); + element->SetStopGradient(stop_gradient_); } bool stop_gradient_ = true; diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc index 970530353c7..b841afff157 100644 --- a/paddle/fluid/pybind/eager_py_layer.cc +++ b/paddle/fluid/pybind/eager_py_layer.cc @@ -315,7 +315,7 @@ PyObject* pylayer_method_apply(PyObject* cls, non_differentiable.end()) { outputs_autograd_meta[i][j]->SetStopGradient(true); } else { - outputs_autograd_meta[i][j]->WeakSetStopGradient(false); + outputs_autograd_meta[i][j]->SetStopGradient(false); } } } -- GitLab