未验证 提交 cccba68c 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager]fix_stop_gradient (#45154)

* fix_stop_gradient
上级 cc12d27f
......@@ -123,12 +123,6 @@ class AutogradMeta : public AbstractAutogradMeta {
stop_gradient_ = static_cast<int>(stop_gradient);
}
// Weakly sets the stop-gradient flag: the write only takes effect while the
// flag is still undecided (sentinel value -1, presumably meaning "not yet
// set" — confirm against the member's initializer). Once the flag has been
// decided, later weak writes are silently ignored.
void WeakSetStopGradient(bool stop_gradient) {
  const bool undecided = stop_gradient_ == -1;
  if (!undecided) {
    return;
  }
  stop_gradient_ = stop_gradient ? 1 : 0;
}
bool Persistable() const { return persistable_; }
void SetPersistable(bool persistable) { persistable_ = persistable; }
......
......@@ -163,7 +163,7 @@ TEST(EagerUtils, PassStopGradient) {
auto_grad1.get(),
auto_grad2.get(),
auto_grad3.get());
CHECK(auto_grad0->StopGradient() == false);
CHECK(auto_grad0->StopGradient() == true);
CHECK(auto_grad1->StopGradient() == true);
CHECK(auto_grad2->StopGradient() == true);
CHECK(auto_grad3->StopGradient() == true);
......
......@@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper<AutogradMeta*> {
VLOG(2) << "Tensor is NULL";
return;
}
element->WeakSetStopGradient(stop_gradient_);
element->SetStopGradient(stop_gradient_);
}
bool stop_gradient_ = true;
......
......@@ -315,7 +315,7 @@ PyObject* pylayer_method_apply(PyObject* cls,
non_differentiable.end()) {
outputs_autograd_meta[i][j]->SetStopGradient(true);
} else {
outputs_autograd_meta[i][j]->WeakSetStopGradient(false);
outputs_autograd_meta[i][j]->SetStopGradient(false);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册