Unverified commit ff7cbaae authored by Weilong Wu, committed by GitHub

[Eager Hook + Inplace] Refactor register_hook and test with inplace operation (#40778)

* disable scatter case in test_inplace_eager_fluid

* Update register_hook logic

* Add register_hook test cases
Co-authored-by: pangyoki <pangyoki@126.com>
Parent fe291daf
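The diff below touches two places: the eager-mode C++ binding behind `Tensor.register_hook`, and the inplace test suite in `test_inplace_eager_fluid`. At the Python level, the pattern the new tests exercise is roughly the following (a minimal sketch only; the tensor values and the choice of `scale_` as the inplace op are illustrative assumptions, not part of this commit):

```python
import paddle

# Minimal sketch of the "hook + inplace" pattern covered by the new tests.
# `scale_` is just one example of an inplace API; the actual tests
# parameterize the inplace op through `inplace_api_processing`.
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
x.register_hook(lambda grad: grad * 2)  # doubles the gradient of the leaf

y = x ** 2
y = y.scale_(1.0)        # intermediate tensor modified in place
loss = (y ** 2).sum()
loss.backward()

print(x.grad)            # leaf gradient, already doubled by the hook
```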
@@ -868,16 +868,22 @@ static PyObject* tensor_register_grad_hook(TensorObject* self, PyObject* args,
  int64_t hook_id;
  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected NULL grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    PADDLE_ENFORCE(
        grad_node.get() != nullptr,
        paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                        "Leaf tensor should have had grad_node "
                                        "with type: GradNodeAccumulation."));
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();

    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    auto accumulation_grad_node =
......
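The C++ change above makes hook registration on a leaf tensor more robust: if the tensor requires a gradient but no grad node has been created yet, `register_hook` now creates the `GradNodeAccumulation` node on the spot instead of tripping the `PADDLE_ENFORCE`. The situation it covers looks roughly like this at the Python level (a hedged sketch; the tensor value and the expected output are illustrative, not taken from the commit):

```python
import paddle

# A leaf tensor that requires grad but has not been through any forward op
# yet, so no accumulation grad node exists when the hook is registered.
w = paddle.to_tensor([0.5], stop_gradient=False)
w.register_hook(lambda grad: grad * 2)  # without the lazy creation above,
                                        # this path could hit
                                        # "Detected NULL grad_node"

loss = (w * 3.0).sum()
loss.backward()
print(w.grad)  # expected 3.0 * 2 = 6.0: the hook doubles the leaf gradient
```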
@@ -171,6 +171,180 @@ class TestDygraphInplace(unittest.TestCase):
                grad_var_a = var_a.grad.numpy()
        self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
    # inplace + hook
    def test_backward_success_3(self):
        # var_b is modified in place before it is used; the inplace operation
        # does not lead to incorrect gradient computation.
        def double_hook(grad):
            grad = grad * 2
            return grad

        grad_var_a, grad_var_a_inplace = 0, 1
        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False
                helper = var_a.register_hook(double_hook)

                var_b = var_a**2
                var_c = self.inplace_api_processing(
                    var_b)  # var_b is modified in place before it is used

                # Here, the gradient computation will use the value of var_b
                var_d = var_c**2
                loss = var_d.sum()
                loss.backward()
                grad_var_a_inplace = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False
                helper = var_a.register_hook(double_hook)

                var_b = var_a**2
                var_c = self.non_inplace_api_processing(var_b)
                var_d = var_c**2
                loss = var_d.sum()
                loss.backward()
                grad_var_a = var_a.grad.numpy()

        self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
    # inplace + hook
    def test_backward_success_4(self):
        # Although var_b is modified in place, its value is not needed for the
        # gradient computation, so the inplace operation does not lead to
        # incorrect gradients.
        def double_hook(grad):
            grad = grad * 2
            return grad

        grad_var_a, grad_var_a_inplace = 0, 1
        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False
                var_a.register_hook(double_hook)

                var_b = var_a**2
                var_c = self.inplace_api_processing(
                    var_b)  # var_b is modified in place before it is used

                # Here, the grad ops don't use the value of var_b
                var_d = var_c + var_c
                loss = var_d.sum()
                loss.backward()
                grad_var_a_inplace = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False
                var_a.register_hook(double_hook)

                var_b = var_a**2
                var_c = self.non_inplace_api_processing(
                    var_b)  # non-inplace counterpart, used as the baseline

                # Here, the grad ops don't use the value of var_b
                var_d = var_c + var_c
                loss = var_d.sum()
                loss.backward()
                grad_var_a = var_a.grad.numpy()

        self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
    # inplace + hook
    def test_backward_success_5(self):
        # var_b is modified in place before it is used; the inplace operation
        # does not lead to incorrect gradient computation.
        def double_hook(grad):
            grad = grad * 2
            return grad

        grad_var_a, grad_var_a_inplace = 0, 1
        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False

                var_b = var_a**2
                var_b.register_hook(double_hook)
                var_c = self.inplace_api_processing(
                    var_b)  # var_b is modified in place before it is used

                # Here, the gradient computation will use the value of var_b
                var_d = var_c**2
                loss = var_d.sum()
                loss.backward()
                grad_var_a_inplace = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False

                var_b = var_a**2
                var_b.register_hook(double_hook)
                var_c = self.non_inplace_api_processing(var_b)
                var_d = var_c**2
                loss = var_d.sum()
                loss.backward()
                grad_var_a = var_a.grad.numpy()

        self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
    # inplace + hook
    def test_backward_success_6(self):
        # Although var_b is modified in place, its value is not needed for the
        # gradient computation, so the inplace operation does not lead to
        # incorrect gradients.
        def double_hook(grad):
            grad = grad * 2
            return grad

        grad_var_a, grad_var_a_inplace = 0, 1
        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False

                var_b = var_a**2
                var_b.register_hook(double_hook)
                var_c = self.inplace_api_processing(
                    var_b)  # var_b is modified in place before it is used

                # Here, the grad ops don't use the value of var_b
                var_d = var_c + var_c
                loss = var_d.sum()
                loss.backward()
                grad_var_a_inplace = var_a.grad.numpy()

        with paddle.fluid.dygraph.guard():
            with _test_eager_guard():
                var_a = paddle.to_tensor(self.input_var_numpy).astype(
                    self.dtype)
                var_a.stop_gradient = False

                var_b = var_a**2
                var_b.register_hook(double_hook)
                var_c = self.non_inplace_api_processing(
                    var_b)  # non-inplace counterpart, used as the baseline

                # Here, the grad ops don't use the value of var_b
                var_d = var_c + var_c
                loss = var_d.sum()
                loss.backward()
                grad_var_a = var_a.grad.numpy()

        self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
    def non_inplace_api_processing(self, var):
......
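The diff is truncated at `TestDygraphInplaceUnsqueeze`. Each subclass of `TestDygraphInplace` only supplies the inplace/out-of-place API pair; the hook-aware test methods above are inherited unchanged. A hypothetical subclass following that pattern would look like the sketch below (the `unsqueeze`/`unsqueeze_` pair is an assumption suggested by the class name, since the actual body is truncated above):

```python
import paddle

# Hypothetical subclass illustrating the pattern; it reuses every test method
# defined on TestDygraphInplace, including the new hook + inplace cases.
class TestDygraphInplaceUnsqueezeSketch(TestDygraphInplace):
    def non_inplace_api_processing(self, var):
        return paddle.unsqueeze(var, -1)  # out-of-place version

    def inplace_api_processing(self, var):
        return var.unsqueeze_(-1)  # inplace counterpart
```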