未验证 提交 c9fc7ba9 编写于 作者: Y Yang Yu

Do not sum output if that output is not a gradient

* increament is default inplace
上级 6d41bfb7
...@@ -408,6 +408,11 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -408,6 +408,11 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
for (const auto& desc : op_grads) { for (const auto& desc : op_grads) {
for (const std::string& out_name : desc->OutputArgumentNames()) { for (const std::string& out_name : desc->OutputArgumentNames()) {
if (out_name.find("@GRAD") == std::string::npos) {
// Not all outputs of a backward operator is a gradient. Only gradient
// need to be sum. Skip variables are not gradient.
continue;
}
dup_out_ops[out_name].emplace_back(grad_desc_idx); dup_out_ops[out_name].emplace_back(grad_desc_idx);
} }
++grad_desc_idx; ++grad_desc_idx;
......
...@@ -823,7 +823,7 @@ def zeros(shape, dtype, main_program=None): ...@@ -823,7 +823,7 @@ def zeros(shape, dtype, main_program=None):
return fill_constant(value=0.0, **locals()) return fill_constant(value=0.0, **locals())
def increment(x, value=1.0, in_place=False, main_program=None): def increment(x, value=1.0, in_place=True, main_program=None):
helper = LayerHelper("increment", **locals()) helper = LayerHelper("increment", **locals())
if in_place: if in_place:
tmp = x tmp = x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册