From 69b79e6f09a954b4cd6bc3b0d16f03534db24134 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 5 Apr 2022 08:23:18 +0800 Subject: [PATCH] ignore no_need_buffer tensor_wrapper in inplace checking (#41350) * support inplace no_need_buffer * fix * use padle.add --- paddle/fluid/eager/tensor_wrapper.h | 2 +- python/paddle/fluid/tests/unittests/test_inplace.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index dc4cf379390..3d5d3139de1 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -51,6 +51,7 @@ class TensorWrapper { * to avoid recursive depends on GradNodeBase * **/ full_reserved_ = full_reserved; + no_need_buffer_ = no_need_buffer; if (full_reserved_) { VLOG(6) << "Fully reserved tensor: " << tensor.name(); intermidiate_tensor_ = tensor; @@ -58,7 +59,6 @@ class TensorWrapper { } // shallow copy tensor_impl here - no_need_buffer_ = no_need_buffer; if (no_need_buffer) { if (phi::DenseTensor::classof(tensor.impl().get())) { // Only Copy Meta diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index b4f1dc22f4e..ee0d5bcdde6 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -103,7 +103,9 @@ class TestInplace(unittest.TestCase): var_b[1:2] = 3 # var_b is modified inplace before using it - var_c = var_b + var_b # Here, the grad op of sum doesn't use the value of var_b + var_c = paddle.add( + var_b, + var_b) # Here, the grad op of sum doesn't use the value of var_b loss = var_c.sum() var_b[1:2] = 3 # var_b is modified inplace after using it @@ -111,9 +113,8 @@ class TestInplace(unittest.TestCase): loss.backward() def test_backward_success_2(self): - # TODO: need to process no_need_buffer in eager mode - # with _test_eager_guard(): - # self.func_test_backward_success_2() + with _test_eager_guard(): + self.func_test_backward_success_2() self.func_test_backward_success_2() -- GitLab