未验证 提交 69b79e6f 编写于 作者: P pangyoki 提交者: GitHub

ignore no_need_buffer tensor_wrapper in inplace checking (#41350)

* support inplace no_need_buffer

* fix

* use padle.add
上级 e90f9367
......@@ -51,6 +51,7 @@ class TensorWrapper {
* to avoid recursive depends on GradNodeBase
* **/
full_reserved_ = full_reserved;
no_need_buffer_ = no_need_buffer;
if (full_reserved_) {
VLOG(6) << "Fully reserved tensor: " << tensor.name();
intermidiate_tensor_ = tensor;
......@@ -58,7 +59,6 @@ class TensorWrapper {
}
// shallow copy tensor_impl here
no_need_buffer_ = no_need_buffer;
if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta
......
......@@ -103,7 +103,9 @@ class TestInplace(unittest.TestCase):
var_b[1:2] = 3 # var_b is modified inplace before using it
var_c = var_b + var_b # Here, the grad op of sum doesn't use the value of var_b
var_c = paddle.add(
var_b,
var_b) # Here, the grad op of sum doesn't use the value of var_b
loss = var_c.sum()
var_b[1:2] = 3 # var_b is modified inplace after using it
......@@ -111,9 +113,8 @@ class TestInplace(unittest.TestCase):
loss.backward()
def test_backward_success_2(self):
# TODO: need to process no_need_buffer in eager mode
# with _test_eager_guard():
# self.func_test_backward_success_2()
with _test_eager_guard():
self.func_test_backward_success_2()
self.func_test_backward_success_2()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册