You need to sign in or sign up before continuing.
提交 a240bce1 编写于 作者: Q qiaolongfei

fix backward

上级 5b7633a5
......@@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.Output() /*names*/, kGradVarSuffix /*suffix*/,
if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) {
ForEachVarName(forwardOp.inputs_,
ForEachVarName(forwardOp.Inputs(),
[&no_grad_names](const std::string& name) -> bool {
no_grad_names.insert(GradVarName(name));
return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册