Unverified commit cb8afc24, authored by pangyoki, committed by GitHub

add _reset_grad_inplace_version (#41101)

Parent a5bfa797
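In brief: the commit exposes a new `_reset_grad_inplace_version` method on the eager-mode Tensor in pybind, backed by a new C++ API `Tensor::reset_inplace_version(bool set_to_zero)` that zeroes the inplace version counter of an initialized dense gradient tensor. A minimal usage sketch follows (not part of the commit; it assumes eager mode and uses only APIs that appear in the diff below):

```python
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard

paddle.set_device('cpu')

with _test_eager_guard():  # run under the experimental eager mode this commit targets
    w = paddle.to_tensor(np.ones([1, 1]), 'float32', stop_gradient=False)
    out = _C_ops.scale(w, 'scale', 0.1)
    out.backward()

    # If w.grad is later mutated in place (e.g. zeroed by a clear-grad hook),
    # its inplace version counter moves past 0. The new method resets it;
    # called with no argument, the pybind wrapper defaults set_to_zero to True.
    w._reset_grad_inplace_version()
    assert w.grad._inplace_version() == 0
```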
@@ -1308,6 +1308,27 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == (Py_ssize_t)1) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::experimental::Tensor* grad =
      egr::EagerUtils::mutable_grad(self->tensor);
  if (grad && grad->defined() && grad->is_dense_tensor() &&
      grad->initialized()) {
    grad->reset_inplace_version(set_to_zero);
  }
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef variable_methods[] = {
    {"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1407,6 +1428,9 @@ PyMethodDef variable_methods[] = {
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS, NULL},
{"_reset_grad_inplace_version",
(PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
    {NULL, NULL, 0, NULL}};
}  // namespace pybind
...
@@ -516,6 +516,11 @@ class PADDLE_API Tensor final {
   */
  uint32_t current_inplace_version();
  /**
   * @brief Reset inplace version
   */
  void reset_inplace_version(bool set_to_zero = false);
  /* Part 10: Auto generated Tensor methods */

  /* Part 11: Methods of converting SparseTensor and DenseTensor to each other
...
@@ -384,5 +384,16 @@ uint32_t Tensor::current_inplace_version() {
  return 0;
}
void Tensor::reset_inplace_version(bool set_to_zero) {
  if (set_to_zero) {
    if (is_dense_tensor()) {
      auto &inplace_version_counter =
          std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
              ->InplaceVersionCounter();
      inplace_version_counter.SetInplaceVersionToZero();
    }
  }
}
}  // namespace experimental
}  // namespace paddle
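Two details worth noting in the implementation above: the C++ API defaults `set_to_zero` to `false`, in which case `reset_inplace_version` is a no-op, whereas the pybind wrapper passes `true` when the Python caller supplies no argument; and only dense tensors are handled, since the counter lives on `phi::DenseTensor`.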
@@ -16,6 +16,7 @@ import paddle
import paddle.fluid as fluid
from paddle import _C_ops
from paddle.fluid import framework
from paddle.fluid.framework import _test_eager_guard
import unittest

paddle.set_device('cpu')
@@ -32,7 +33,7 @@ def clear_grad_test_0(w, a):
class TestInplaceAndClearGradient(unittest.TestCase):
    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
@@ -45,6 +46,11 @@ class TestInplaceAndClearGradient(unittest.TestCase):
        out.backward()
        assert w.grad[0] == 0.15
    def test(self):
        with _test_eager_guard():
            self.func_test()
        self.func_test()
# Test 2
class Counter:
@@ -67,7 +73,7 @@ def clear_grad_test_1(w, c):
class TestInplaceClearGradAccumulation(unittest.TestCase):
    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
        c = Counter()
@@ -87,9 +93,14 @@ class TestInplaceClearGradAccumulation(unittest.TestCase):
        assert c.num_calls == 1
        c.num_calls = 0
    def test(self):
        with _test_eager_guard():
            self.func_test()
        self.func_test()
class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
        out = _C_ops.scale(w, 'scale', 0.1)
@@ -100,6 +111,11 @@ class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
        assert w.grad._inplace_version() == 1
    def test(self):
        with _test_eager_guard():
            self.func_test()
        self.func_test()
if __name__ == '__main__':
    unittest.main()
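The `func_test`/`test` split in the tests above is the stock transition-era pattern: each case runs once inside `_test_eager_guard()` (exercising the new eager-mode binding) and once more in legacy dygraph mode.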