diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 62d449ccd2ea8c873629a6dade5fce2fac167aed..a31fd436c71644310b5be0ffc57c5c1a5132e009 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -569,6 +569,13 @@ void BasicEngine::Execute() {
         }
       }
 
+      // Function Post Hook
+      if (cur_op.HasVoidFunctionPostHook()) {
+        for (const auto& hook : cur_op.GetVoidFunctionPostHooks()) {
+          (*hook)();
+        }
+      }
+
       for (auto& pair : inplace_output_grad_var_list_) {
         *pair.first = std::move(*pair.second);
       }
diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h
index 4122e2af3dedaee0b0dfd74923870b7137fe73a3..3ff451f81787209eca36acdf9aac62519a90f399 100644
--- a/paddle/fluid/imperative/op_base.h
+++ b/paddle/fluid/imperative/op_base.h
@@ -186,6 +186,19 @@ class OpBase {
   static pten::KernelContext* GetKernelContext() {
     return &pt_kernel_context_;
   }
 
+  bool HasVoidFunctionPostHook() const {
+    return !void_function_post_hooks_.empty();
+  }
+
+  void AddVoidFunctionPostHook(std::shared_ptr<std::function<void()>>&& hook) {
+    void_function_post_hooks_.emplace_back(std::move(hook));
+  }
+
+  const std::vector<std::shared_ptr<std::function<void()>>>&
+  GetVoidFunctionPostHooks() const {
+    return void_function_post_hooks_;
+  }
+
  private:
   static const std::string& UnknownOpType() {
     static std::string kUnknownOpType{"unknown"};
@@ -203,6 +216,7 @@ class OpBase {
   // In order to reduce the compatibility phase
   // performance overhead, temporarily cache KernelContext
   static pten::KernelContext pt_kernel_context_;
+  std::vector<std::shared_ptr<std::function<void()>>> void_function_post_hooks_;
 };
 
 class GradOpNode {
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 445654efb54a676e132154b21e65704ce90d17e0..7f6e3644bc3293239915e2c9b34f5f6404b80ccc 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -1640,6 +1640,26 @@ void BindImperative(py::module *m_ptr) {
                      "gradient or without gradient."));
              return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
            })
+      .def("_register_void_function_post_hook",
+           [](imperative::VarBase &self, const py::handle &hook) {
+             PADDLE_ENFORCE_EQ(
+                 !self.OverridedStopGradient() && self.HasGradVar(), true,
+                 platform::errors::InvalidArgument(
+                     "Cannot register void function post hook on a Tensor that "
+                     "stop "
+                     "gradient or without gradient."));
+             auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
+             VLOG(1) << 111;
+             auto grad_node = self.MutableGradVarBase()->GradNode();
+             VLOG(1) << 222;
+             VLOG(1) << (grad_node == nullptr);
+             for (auto &cur_op : *grad_node) {
+               VLOG(1) << 333;
+               cur_op.AddVoidFunctionPostHook(
+                   std::make_shared<std::function<void()>>(py_func));
+               VLOG(1) << 444;
+             }
+           })
       .def("_register_backward_hook",
            [](imperative::VarBase &self, const py::handle &hook) {
              PADDLE_ENFORCE_EQ(
diff --git a/python/paddle/fluid/tests/unittests/test_function_hook.py b/python/paddle/fluid/tests/unittests/test_function_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..d45ef528261f394ff54e94714aa40d345b9aa458
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_function_hook.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle
+import numpy as np
+
+import paddle.fluid.core as core
+from paddle import _C_ops
+
+
+class TestCapture:
+    def __init__(self):
+        self.list = []
+
+
+test_cap = TestCapture()
+
+
+def test_hook():
+    test_cap.list.append(1)
+
+
+def grad_hook(grad):
+    test_cap.list.append(2)
+
+    return grad
+
+
+class TestBackwardFunctionHookError(unittest.TestCase):
+    def test_hook(self):
+        input_data = np.ones([4, 4]).astype('float32')
+
+        x = paddle.to_tensor(input_data.astype(np.float32), stop_gradient=False)
+        z = paddle.to_tensor(input_data.astype(np.float32), stop_gradient=False)
+
+        y = _C_ops.sigmoid(x)
+        out = _C_ops.matmul_v2(y, z, 'trans_x', False, 'trans_y', False)
+
+        out._register_void_function_post_hook(test_hook)
+        y._register_void_function_post_hook(test_hook)
+        y.register_hook(grad_hook)
+
+        out.backward()
+
+        assert test_cap.list == [1, 2, 1]
+
+
+if __name__ == "__main__":
+    unittest.main()
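
For reference, a minimal usage sketch of the new _register_void_function_post_hook binding. This is illustrative only: it assumes the patch above is applied and the code runs in dynamic-graph mode, mirrors the unit test but goes through paddle.matmul and paddle.nn.functional.sigmoid instead of the raw _C_ops calls, and the names post_hook and calls are made up for the example.

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    calls = []


    def post_hook():
        # Takes no arguments and returns nothing; BasicEngine::Execute calls it
        # right after the corresponding grad op finishes during backward.
        calls.append(len(calls) + 1)


    x = paddle.to_tensor(np.ones([4, 4], dtype='float32'), stop_gradient=False)
    z = paddle.to_tensor(np.ones([4, 4], dtype='float32'), stop_gradient=False)

    y = F.sigmoid(x)
    out = paddle.matmul(y, z)

    # Register after the forward ops have recorded their grad nodes.
    out._register_void_function_post_hook(post_hook)
    y._register_void_function_post_hook(post_hook)

    out.backward()
    print(calls)  # expected [1, 2]: matmul grad runs first, then sigmoid grad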