未验证 提交 9b9a225b 编写于 作者: P pangyoki 提交者: GitHub

Cherry-pick PR 38406, fix accumulator bug when multiple inplace OPs are executed continuously

Cherry-pick PR 38406, fix accumulator bug when multiple inplace OPs are executed continuously  (#38406) (#38830)

Cherry pick PR #38406
上级 27774eec
......@@ -174,7 +174,7 @@ void BasicEngine::PrepareGradAccumulators(
if (!var) continue;
bool find_grad_node_of_var = false;
if (var->HasGradNode()) {
if (grad_pending_nodes.size()) {
// Because Inplace op overwrites the grad_node of the input grad_var. So
// only the information of grad_pending_node can be used to find the
// grad_node of grad_var.
......@@ -240,7 +240,7 @@ void BasicEngine::PrepareGradAccumulators(
}
}
if (!var->HasGradNode() || !find_grad_node_of_var) {
if (!grad_pending_nodes.size() || !find_grad_node_of_var) {
auto& accumulator = accumulators_[var.get()];
if (!accumulator) {
if (FLAGS_sort_sum_gradient) {
......@@ -438,15 +438,15 @@ void BasicEngine::Execute() {
continue;
}
const auto& grad_pending_nodes = shared_cur_node->GradPendingNodes();
std::unordered_map<VariableWrapper*,
std::unique_ptr<GradientAccumulator>>::iterator
iter;
bool flag_find_grad = false;
if (var->HasGradNode()) {
if (grad_pending_nodes.size()) {
VLOG(10) << "Find gradient of var (" << var->Name()
<< ") with grad_node.";
for (auto& grad_pending_node :
shared_cur_node->GradPendingNodes()) {
for (auto& grad_pending_node : grad_pending_nodes) {
const auto& iter_grad_node =
accumulators_with_grad_node_.find(grad_pending_node);
if (iter_grad_node != accumulators_with_grad_node_.end()) {
......@@ -458,10 +458,11 @@ void BasicEngine::Execute() {
}
}
if (!flag_find_grad) {
VLOG(6) << "Cannot find gradient of variable " << var->Name();
VLOG(6) << "Cannot find gradient of variable " << var->Name()
<< " in accumulators_with_grad_node_";
}
}
if (!var->HasGradNode() || !flag_find_grad) {
if (!grad_pending_nodes.size() || !flag_find_grad) {
VLOG(10) << "Find gradient of var (" << var->Name()
<< ") with no grad_node.";
iter = accumulators_.find(var.get());
......
......@@ -434,5 +434,18 @@ class TestLossIsInplaceVar(unittest.TestCase):
self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
class TestContinuouslyInplace(unittest.TestCase):
    """Regression test for the gradient-accumulator bug triggered when
    several inplace OPs run back-to-back on the same tensor (PR #38406)."""

    def test_continuously_inplace(self):
        # A leaf tensor that requires grad, so backward() must reach it
        # through the chain of inplace reshapes.
        source = paddle.rand([2, 3])
        source.stop_gradient = False

        doubled = source * 2

        # Apply three consecutive inplace reshapes; each one overwrites the
        # grad_node of the input grad_var, which is what used to break the
        # accumulator bookkeeping.
        for shape in ([-1], [2, 3], [-1]):
            doubled.reshape_(shape)

        # Must complete without crashing / mis-accumulating gradients.
        doubled.backward()
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册