未验证 提交 cccba68c 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager]fix_stop_gradient (#45154)

* fix_stop_gradient
上级 cc12d27f
...@@ -123,12 +123,6 @@ class AutogradMeta : public AbstractAutogradMeta { ...@@ -123,12 +123,6 @@ class AutogradMeta : public AbstractAutogradMeta {
stop_gradient_ = static_cast<int>(stop_gradient); stop_gradient_ = static_cast<int>(stop_gradient);
} }
void WeakSetStopGradient(bool stop_gradient) {
if (stop_gradient_ == -1) {
stop_gradient_ = static_cast<int>(stop_gradient);
}
}
bool Persistable() const { return persistable_; } bool Persistable() const { return persistable_; }
void SetPersistable(bool persistable) { persistable_ = persistable; } void SetPersistable(bool persistable) { persistable_ = persistable; }
......
...@@ -163,7 +163,7 @@ TEST(EagerUtils, PassStopGradient) { ...@@ -163,7 +163,7 @@ TEST(EagerUtils, PassStopGradient) {
auto_grad1.get(), auto_grad1.get(),
auto_grad2.get(), auto_grad2.get(),
auto_grad3.get()); auto_grad3.get());
CHECK(auto_grad0->StopGradient() == false); CHECK(auto_grad0->StopGradient() == true);
CHECK(auto_grad1->StopGradient() == true); CHECK(auto_grad1->StopGradient() == true);
CHECK(auto_grad2->StopGradient() == true); CHECK(auto_grad2->StopGradient() == true);
CHECK(auto_grad3->StopGradient() == true); CHECK(auto_grad3->StopGradient() == true);
......
...@@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper<AutogradMeta*> { ...@@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper<AutogradMeta*> {
VLOG(2) << "Tensor is NULL"; VLOG(2) << "Tensor is NULL";
return; return;
} }
element->WeakSetStopGradient(stop_gradient_); element->SetStopGradient(stop_gradient_);
} }
bool stop_gradient_ = true; bool stop_gradient_ = true;
......
...@@ -315,7 +315,7 @@ PyObject* pylayer_method_apply(PyObject* cls, ...@@ -315,7 +315,7 @@ PyObject* pylayer_method_apply(PyObject* cls,
non_differentiable.end()) { non_differentiable.end()) {
outputs_autograd_meta[i][j]->SetStopGradient(true); outputs_autograd_meta[i][j]->SetStopGradient(true);
} else { } else {
outputs_autograd_meta[i][j]->WeakSetStopGradient(false); outputs_autograd_meta[i][j]->SetStopGradient(false);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册