未验证 提交 31344ab7 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Add backward function hook to dygraph (#37141)

上级 21957476
......@@ -569,6 +569,13 @@ void BasicEngine::Execute() {
}
}
// Function Post Hook
if (cur_op.HasVoidFunctionPostHook()) {
for (const auto& hook : cur_op.GetVoidFunctionPostHooks()) {
(*hook)();
}
}
for (auto& pair : inplace_output_grad_var_list_) {
*pair.first = std::move(*pair.second);
}
......
......@@ -186,6 +186,19 @@ class OpBase {
static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }
bool HasVoidFunctionPostHook() const {
return !void_function_post_hooks_.empty();
}
void AddVoidFunctionPostHook(std::shared_ptr<std::function<void()>>&& hook) {
void_function_post_hooks_.emplace_back(std::move(hook));
}
const std::vector<std::shared_ptr<std::function<void()>>>&
GetVoidFunctionPostHooks() const {
return void_function_post_hooks_;
}
private:
static const std::string& UnknownOpType() {
static std::string kUnknownOpType{"unknown"};
......@@ -203,6 +216,7 @@ class OpBase {
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
static pten::KernelContext pt_kernel_context_;
std::vector<std::shared_ptr<std::function<void()>>> void_function_post_hooks_;
};
class GradOpNode {
......
......@@ -1640,6 +1640,26 @@ void BindImperative(py::module *m_ptr) {
"gradient or without gradient."));
return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
})
.def("_register_void_function_post_hook",
[](imperative::VarBase &self, const py::handle &hook) {
PADDLE_ENFORCE_EQ(
!self.OverridedStopGradient() && self.HasGradVar(), true,
platform::errors::InvalidArgument(
"Cannot register void function post hook on a Tensor that "
"stop "
"gradient or without gradient."));
auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
VLOG(1) << 111;
auto grad_node = self.MutableGradVarBase()->GradNode();
VLOG(1) << 222;
VLOG(1) << (grad_node == nullptr);
for (auto &cur_op : *grad_node) {
VLOG(1) << 333;
cur_op.AddVoidFunctionPostHook(
std::make_shared<std::function<void()>>(py_func));
VLOG(1) << 444;
}
})
.def("_register_backward_hook",
[](imperative::VarBase &self, const py::handle &hook) {
PADDLE_ENFORCE_EQ(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from paddle import _C_ops
class TestCapture:
def __init__(self):
self.list = []
test_cap = TestCapture()
def test_hook():
test_cap.list.append(1)
def grad_hook(grad):
test_cap.list.append(2)
return grad
class TestBakcwardFunctionHookError(unittest.TestCase):
def test_hook(self):
input_data = np.ones([4, 4]).astype('float32')
x = paddle.to_tensor(input_data.astype(np.float32), stop_gradient=False)
z = paddle.to_tensor(input_data.astype(np.float32), stop_gradient=False)
y = _C_ops.sigmoid(x)
out = _C_ops.matmul_v2(y, z, 'trans_x', False, 'trans_y', False)
out._register_void_function_post_hook(test_hook)
y._register_void_function_post_hook(test_hook)
y.register_hook(grad_hook)
out.backward()
assert test_cap.list == [1, 2, 1]
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册