未验证 提交 69b79e6f 编写于 作者: P pangyoki 提交者: GitHub

ignore no_need_buffer tensor_wrapper in inplace checking (#41350)

* support inplace no_need_buffer

* fix

* use padle.add
上级 e90f9367
...@@ -51,6 +51,7 @@ class TensorWrapper { ...@@ -51,6 +51,7 @@ class TensorWrapper {
* to avoid recursive depends on GradNodeBase * to avoid recursive depends on GradNodeBase
* **/ * **/
full_reserved_ = full_reserved; full_reserved_ = full_reserved;
no_need_buffer_ = no_need_buffer;
if (full_reserved_) { if (full_reserved_) {
VLOG(6) << "Fully reserved tensor: " << tensor.name(); VLOG(6) << "Fully reserved tensor: " << tensor.name();
intermidiate_tensor_ = tensor; intermidiate_tensor_ = tensor;
...@@ -58,7 +59,6 @@ class TensorWrapper { ...@@ -58,7 +59,6 @@ class TensorWrapper {
} }
// shallow copy tensor_impl here // shallow copy tensor_impl here
no_need_buffer_ = no_need_buffer;
if (no_need_buffer) { if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) { if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta // Only Copy Meta
......
...@@ -103,7 +103,9 @@ class TestInplace(unittest.TestCase): ...@@ -103,7 +103,9 @@ class TestInplace(unittest.TestCase):
var_b[1:2] = 3 # var_b is modified inplace before using it var_b[1:2] = 3 # var_b is modified inplace before using it
var_c = var_b + var_b # Here, the grad op of sum doesn't use the value of var_b var_c = paddle.add(
var_b,
var_b) # Here, the grad op of sum doesn't use the value of var_b
loss = var_c.sum() loss = var_c.sum()
var_b[1:2] = 3 # var_b is modified inplace after using it var_b[1:2] = 3 # var_b is modified inplace after using it
...@@ -111,9 +113,8 @@ class TestInplace(unittest.TestCase): ...@@ -111,9 +113,8 @@ class TestInplace(unittest.TestCase):
loss.backward() loss.backward()
def test_backward_success_2(self): def test_backward_success_2(self):
# TODO: need to process no_need_buffer in eager mode with _test_eager_guard():
# with _test_eager_guard(): self.func_test_backward_success_2()
# self.func_test_backward_success_2()
self.func_test_backward_success_2() self.func_test_backward_success_2()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册