提交 a240bce1 编写于 作者: Q qiaolongfei

fix backward

上级 5b7633a5
...@@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into // Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP. // `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.Output() /*names*/, kGradVarSuffix /*suffix*/, if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) { no_grad_names /*set*/)) {
ForEachVarName(forwardOp.inputs_, ForEachVarName(forwardOp.Inputs(),
[&no_grad_names](const std::string& name) -> bool { [&no_grad_names](const std::string& name) -> bool {
no_grad_names.insert(GradVarName(name)); no_grad_names.insert(GradVarName(name));
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册