未验证 提交 acde295c 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] refactor general_grad and fix some bugs (#44611)

* refactor general_grad and fix some bugs

* add TODO: support prune logic deeper
上级 d4cf02bc
此差异已折叠。
此差异已折叠。
......@@ -253,6 +253,19 @@ class GradNodeBase {
* **/
inline bool GradientHooksRegistered() { return !gradient_hooks_.empty(); }
std::map<int64_t, std::tuple<size_t, size_t, std::shared_ptr<TensorHook>>>
GetGradientHookFuntions() {
VLOG(6) << "GetGradientHookFuntions ";
return gradient_hooks_;
}
void SetGradientHookFuntions(
std::map<int64_t, std::tuple<size_t, size_t, std::shared_ptr<TensorHook>>>
hooks) {
VLOG(6) << "SetGradientHookFuntions ";
gradient_hooks_ = hooks;
}
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>
ApplyGradientHooks(
......
......@@ -166,6 +166,46 @@ class TestEagerGrad(TestCase):
self.func_simple_example_eager_grad_duplicate_output()
self.func_simple_example_eager_grad_duplicate_output()
def test_simple_example_eager_two_grad_output(self):
with _test_eager_guard():
x1 = paddle.to_tensor([1.0, 2.0])
x1.stop_gradient = False
x2 = paddle.to_tensor([1.0, 2.0])
x2.stop_gradient = False
out1 = x1 * 2
out2 = x2 * 2
dout2_record_by_hook = []
def record_hook(grad):
dout2_record_by_hook.append(grad)
out2.register_hook(record_hook)
out3 = paddle.multiply(out1, out2)
out4 = paddle.mean(out3)
egr_dout2, egr_dout3 = paddle.grad([out4], [out2, out3])
self.assertTrue(
np.array_equal(dout2_record_by_hook[0].numpy(),
np.array([1., 2.])))
x1 = paddle.to_tensor([1.0, 2.0])
x1.stop_gradient = False
x2 = paddle.to_tensor([1.0, 2.0])
x2.stop_gradient = False
out1 = x1 * 2
out2 = x2 * 2
out3 = paddle.multiply(out1, out2)
out4 = paddle.mean(out3)
dout2, dout3 = paddle.grad([out4], [out2, out3])
self.assertEqual(dout2.stop_gradient, egr_dout2.stop_gradient)
self.assertEqual(dout3.stop_gradient, egr_dout3.stop_gradient)
self.assertTrue(np.array_equal(dout2.numpy(), egr_dout2.numpy()))
self.assertTrue(np.array_equal(dout3.numpy(), egr_dout3.numpy()))
class TestDygraphDoubleGrad(TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册