Unverified commit 8c43c0fe, authored by Jiabin Yang, committed by GitHub

Support backward final hook (#44686)

Parent: b7496bcb
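
This commit adds a "backward final hook": a callback registered on egr::Controller that RunBackward invokes exactly once after the whole backward pass has finished, and then clears. Alongside that, the void-hook types are renamed (TensorVoidHook to VoidHook, CppTensorVoidHook to CppVoidHook, PyTensorVoidHook to PyVoidHook), the PyTensorHook and PyVoidHook wrappers are consolidated in the shared pybind eager utils header and source, and the egr_utils_api registration helpers now take plain std::function callables instead of pre-wrapped hook objects. A Python entry point, core.eager._add_backward_final_hook, is exposed and covered by a new unit test.

A minimal C++-side usage sketch, with names taken from the diff below; header includes and tensor setup are omitted, and the wrapper function itself is illustrative only:

void RunWithFinalHook(const paddle::experimental::Tensor& loss) {
  bool hook_fired = false;
  // The callback runs once, after every grad node has been processed.
  egr::egr_utils_api::RegisterBackwardFinalHook([&hook_fired]() { hook_fired = true; });

  std::vector<paddle::experimental::Tensor> targets = {loss};
  egr::Backward(targets, /*grad_tensors=*/{});
  // hook_fired is now true.
}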
......@@ -321,7 +321,7 @@ EagerReducer::EagerReducer(
const auto &accumulation_grad_node =
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook(
std::make_shared<egr::CppTensorVoidHook>(reduce_hook));
std::make_shared<egr::CppVoidHook>(reduce_hook));
gradnode_index_map_[grad_node.get()] = global_var_index;
}
......
......@@ -127,7 +127,7 @@ GradNodeAccumulation::operator()(
}
void GradNodeAccumulation::RegisterReduceHook(
std::shared_ptr<TensorVoidHook>&& hook) {
std::shared_ptr<VoidHook>&& hook) {
reduce_hooks_.emplace_back(std::move(hook));
}
......
......@@ -51,7 +51,7 @@ class GradNodeAccumulation : public GradNodeBase {
/**
* Register ReduceHook
* **/
void RegisterReduceHook(std::shared_ptr<TensorVoidHook>&& hook);
void RegisterReduceHook(std::shared_ptr<VoidHook>&& hook);
/**
* Apply ReduceHook here
......@@ -70,7 +70,7 @@ class GradNodeAccumulation : public GradNodeBase {
// TODO(Jiabin): remove this when we make our clear gradient really cleared;
bool is_fake_empty_ = {false};
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
std::vector<std::shared_ptr<VoidHook>> reduce_hooks_;
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
retain_grad_hook_;
......
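
For reference, registering a reduce hook directly on an accumulation node now looks like the call sites changed above. A sketch reusing names from this diff, with the grad-node lookup written as in hook_utils.cc; the wrapper function is illustrative only:

void AttachReduceHook(const paddle::experimental::Tensor& leaf_tensor) {
  // Only leaf tensors have a GradNodeAccumulation to hang the hook on.
  std::shared_ptr<egr::GradNodeBase> grad_node = egr::EagerUtils::grad_node(leaf_tensor);
  auto accumulation_node =
      std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
  accumulation_node->RegisterReduceHook(
      std::make_shared<egr::CppVoidHook>([]() { VLOG(6) << "grad reduced"; }));
}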
......@@ -18,11 +18,11 @@
#include <atomic>
#include <memory>
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/utils/small_vector.h"
namespace egr {
class UniqueNameGenerator {
public:
......@@ -85,6 +85,22 @@ class Controller {
GetCustomEdgesSlotMap() {
return custom_edges_slot_map_;
}
// For Cpp Hook
void RegisterBackwardFinalHook(const std::function<void()>& call_back) {
VLOG(6) << "RegisterBackwardFinalHook";
final_backward_hooks_.emplace_back(
std::make_shared<CppVoidHook>(std::move(call_back)));
VLOG(6) << "Size: " << final_backward_hooks_.size();
}
// For Python hook
void RegisterBackwardFinalHook(const std::shared_ptr<VoidHook>& call_back) {
final_backward_hooks_.emplace_back(call_back);
}
const std::vector<std::shared_ptr<VoidHook>>& FinalBackwardHooks() const {
return final_backward_hooks_;
}
void ClearFinalBackwardHooks() { final_backward_hooks_.clear(); }
private:
Controller() = default;
......@@ -98,6 +114,7 @@ class Controller {
std::unordered_map<std::string,
std::vector<std::vector<std::unordered_map<int, int>>>>
custom_edges_slot_map_;
std::vector<std::shared_ptr<VoidHook>> final_backward_hooks_;
DISABLE_COPY_AND_ASSIGN(Controller);
};
......
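
The Controller now owns the hook list and exposes two registration overloads: the std::function overload (used from C++) wraps the callable in a CppVoidHook, while the shared_ptr<VoidHook> overload lets the Python binding hand in a PyVoidHook directly. A sketch of both calls, assuming the controller header is included:

void RegisterBothWays() {
  // C++ path: a plain callable, wrapped into CppVoidHook internally.
  egr::Controller::Instance().RegisterBackwardFinalHook(
      []() { VLOG(6) << "final hook (std::function overload)"; });

  // Binding path: a pre-built VoidHook (the pybind layer passes a PyVoidHook here).
  std::shared_ptr<egr::VoidHook> hook =
      std::make_shared<egr::CppVoidHook>([]() { VLOG(6) << "final hook (VoidHook overload)"; });
  egr::Controller::Instance().RegisterBackwardFinalHook(hook);

  // RunBackward later walks FinalBackwardHooks() and then calls ClearFinalBackwardHooks().
}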
......@@ -25,17 +25,20 @@ namespace egr_utils_api {
int64_t RegisterGradientHookForTensor(
const paddle::experimental::Tensor& tensor,
std::shared_ptr<egr::TensorHook>&& hook) {
const std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>& hook) {
// Find grad_node and out_rank from AutogradMeta
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
auto rank_info = EagerUtils::unsafe_autograd_meta(tensor)->OutRankInfo();
return grad_node->RegisterGradientHook(
rank_info.first, rank_info.second, std::move(hook));
rank_info.first,
rank_info.second,
std::move(std::make_shared<CppTensorHook>(hook)));
}
void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
std::shared_ptr<egr::TensorVoidHook>&& hook) {
const std::function<void()>& hook) {
if (IsLeafTensor(tensor)) {
VLOG(6) << "Register ReduceHook for leaf tensor";
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
......@@ -46,7 +49,8 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
"with type: GradNodeAccumulation"));
auto accumulation_grad_node =
std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook(std::move(hook));
accumulation_grad_node->RegisterReduceHook(
std::move(std::make_shared<CppVoidHook>(hook)));
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Only can register reduce hook for leaf Tensor."));
......@@ -90,10 +94,12 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
};
// Append to GradientHooks
RegisterGradientHookForTensor(tensor,
std::make_shared<egr::CppTensorHook>(hook));
RegisterGradientHookForTensor(tensor, hook);
}
}
void RegisterBackwardFinalHook(const std::function<void()>& hook) {
Controller::Instance().RegisterBackwardFinalHook(hook);
}
} // namespace egr_utils_api
} // namespace egr
......@@ -23,11 +23,14 @@ namespace egr_utils_api {
int64_t RegisterGradientHookForTensor(
const paddle::experimental::Tensor& tensor,
std::shared_ptr<egr::TensorHook>&& hook);
const std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>& hook);
void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
std::shared_ptr<egr::TensorVoidHook>&& hook);
const std::function<void()>& hook);
void RetainGradForTensor(const paddle::experimental::Tensor& tensor);
void RegisterBackwardFinalHook(const std::function<void()>& hook);
} // namespace egr_utils_api
} // namespace egr
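
With the signature change above, callers pass lambdas (or any std::function) instead of constructing CppTensorHook / CppVoidHook themselves. A sketch of the three helpers; tensor construction is omitted and the hook bodies are placeholders:

void RegisterHooks(const paddle::experimental::Tensor& t) {
  // Gradient hook: receives the incoming grad and returns the (possibly adjusted) grad.
  int64_t hook_id = egr::egr_utils_api::RegisterGradientHookForTensor(
      t, [](const paddle::experimental::Tensor& grad) { return grad; });
  (void)hook_id;  // the tests keep this id so the hook can be removed later

  // Reduce hook: only valid for leaf tensors whose grad node is GradNodeAccumulation.
  egr::egr_utils_api::RegisterReduceHookForTensor(t, []() { /* runs on grad reduce */ });

  // Backward final hook: runs once after the whole backward pass.
  egr::egr_utils_api::RegisterBackwardFinalHook([]() { /* runs after Backward */ });
}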
......@@ -371,6 +371,12 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
}
VLOG(6) << "Run Backward Final hook size: "
<< egr::Controller::Instance().FinalBackwardHooks().size();
for (auto& hook : egr::Controller::Instance().FinalBackwardHooks()) {
(*hook)();
}
egr::Controller::Instance().ClearFinalBackwardHooks();
if (!is_general_grad) return {};
return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}
......
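
Because RunBackward drains and then clears the list, each registration is one-shot: a later Backward() call will not re-run the hook unless it is registered again. A short sketch:

void OneShotFinalHook(const paddle::experimental::Tensor& loss) {
  egr::egr_utils_api::RegisterBackwardFinalHook(
      []() { VLOG(6) << "runs after this backward pass only"; });
  egr::Backward({loss}, /*grad_tensors=*/{});
  // ClearFinalBackwardHooks() has run; re-register before the next Backward()
  // if the callback should fire again.
}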
......@@ -29,16 +29,16 @@ class TensorHook {
const paddle::experimental::Tensor& var) = 0;
};
class TensorVoidHook {
class VoidHook {
public:
virtual ~TensorVoidHook() = default;
virtual ~VoidHook() = default;
virtual void operator()() = 0;
};
class CppTensorHook : public TensorHook {
public:
explicit CppTensorHook(std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>&& fn)
explicit CppTensorHook(const std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>& fn)
: fn_(std::move(fn)) {}
paddle::experimental::Tensor operator()(
......@@ -52,13 +52,14 @@ class CppTensorHook : public TensorHook {
fn_;
};
class CppTensorVoidHook : public TensorVoidHook {
class CppVoidHook : public VoidHook {
public:
explicit CppTensorVoidHook(std::function<void()>&& fn) : fn_(std::move(fn)) {}
explicit CppVoidHook(const std::function<void()>& fn) : fn_(std::move(fn)) {}
void operator()() override { return fn_(); }
private:
std::function<void()> fn_;
};
} // namespace egr
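
The renamed types form a small hierarchy: VoidHook is the abstract parameterless hook, CppVoidHook adapts any std::function<void()>, and PyVoidHook (in the pybind layer further down) plays the same role for a Python callable. A sketch with a hypothetical custom subclass, for illustration only:

// CountingHook is a hypothetical subclass, not part of this commit.
class CountingHook : public egr::VoidHook {
 public:
  void operator()() override { ++count_; }
  int count() const { return count_; }

 private:
  int count_ = 0;
};

void DemoVoidHooks() {
  auto counting = std::make_shared<CountingHook>();
  egr::Controller::Instance().RegisterBackwardFinalHook(counting);  // shared_ptr<VoidHook> overload

  auto wrapped = std::make_shared<egr::CppVoidHook>([]() { VLOG(6) << "wrapped lambda"; });
  (*wrapped)();  // invokes the stored std::function
}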
......@@ -328,8 +328,7 @@ TEST(AccumulationNode, Tensor) {
VLOG(6) << "Running Reduce Hook";
};
node->RegisterReduceHook(
std::make_shared<egr::CppTensorVoidHook>(reduce_hook_1));
node->RegisterReduceHook(std::make_shared<egr::CppVoidHook>(reduce_hook_1));
// operator()
paddle::experimental::Tensor _ret = node->operator()(et0_vec)[0][0];
......@@ -354,8 +353,7 @@ TEST(AccumulationNode, Tensor) {
ret_et0_ptr[0] = 100.0; // set to 100.0
VLOG(6) << "Running Reduce Hook";
};
node->RegisterReduceHook(
std::make_shared<egr::CppTensorVoidHook>(reduce_hook_2));
node->RegisterReduceHook(std::make_shared<egr::CppVoidHook>(reduce_hook_2));
node->ApplyReduceHooks();
// Check ApplyReduceHooks result
......
......@@ -256,8 +256,8 @@ TEST(FwdBwdJoint, GradientHook) {
true /*bias_after_scale*/,
true /*trace_backward*/);
egr_utils_api::RetainGradForTensor(out0); // hook: +5
egr_utils_api::RegisterGradientHookForTensor(
out0, std::make_shared<egr::CppTensorHook>(hook_function)); // hook: +5
egr_utils_api::RegisterGradientHookForTensor(out0,
hook_function); // hook: +5
// Run Forward Node 1
float scale1 = 5.0;
......@@ -265,8 +265,8 @@ TEST(FwdBwdJoint, GradientHook) {
paddle::experimental::Tensor out1 = egr::scale(
out0, scale1, bias1, true /*bias_after_scale*/, true /*trace_backward*/);
egr_utils_api::RetainGradForTensor(out1); // hook: +5
egr_utils_api::RegisterGradientHookForTensor(
out1, std::make_shared<egr::CppTensorHook>(hook_function)); // hook: +5
egr_utils_api::RegisterGradientHookForTensor(out1,
hook_function); // hook: +5
// Run Forward Node 2
float scale2 = 10.0;
......@@ -274,8 +274,8 @@ TEST(FwdBwdJoint, GradientHook) {
paddle::experimental::Tensor out2 = egr::scale(
out0, scale2, bias2, true /*bias_after_scale*/, true /*trace_backward*/);
egr_utils_api::RetainGradForTensor(out2); // hook: +5
egr_utils_api::RegisterGradientHookForTensor(
out2, std::make_shared<egr::CppTensorHook>(hook_function)); // hook: +5
egr_utils_api::RegisterGradientHookForTensor(out2,
hook_function); // hook: +5
// 4. Run Backward
std::vector<paddle::experimental::Tensor> outs = {out1, out2};
......
......@@ -95,8 +95,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
auto_grad_meta));
egr_utils_api::RegisterGradientHookForTensor(
target_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook_function);
egr_utils_api::RetainGradForTensor(
target_tensor); // result: 1.0 + 3.0 = 4.0
egr_utils_api::RetainGradForTensor(
......@@ -122,8 +121,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
tmp_tensor0.mutable_autograd_meta()));
egr_utils_api::RegisterGradientHookForTensor(
leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook_function);
egr_utils_api::RetainGradForTensor(
leaf_tensor); // result: 4.0*5.0 + 3.0 = 23.0
}
......@@ -173,8 +171,7 @@ TEST(RetainGrad, HookAfterRetainGrad) {
auto_grad_meta));
egr_utils_api::RetainGradForTensor(target_tensor); // result: 1.0
egr_utils_api::RegisterGradientHookForTensor(
target_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook_function);
}
// Retain Grad for leaf tensor1
......@@ -193,8 +190,7 @@ TEST(RetainGrad, HookAfterRetainGrad) {
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
tmp_tensor0.mutable_autograd_meta()));
egr_utils_api::RegisterGradientHookForTensor(
leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook_function);
}
Backward(target_tensors, {});
......
......@@ -89,12 +89,11 @@ void test_sigmoid(bool is_remove_gradient_hook) {
egr_utils_api::RetainGradForTensor(tensor);
VLOG(6) << "Register GradientHook for Tensor";
int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
tensor, std::make_shared<CppTensorHook>(hook_function));
int64_t hook_id =
egr_utils_api::RegisterGradientHookForTensor(tensor, hook_function);
VLOG(6) << "Register ReduceHook for Tensor";
egr_utils_api::RegisterReduceHookForTensor(
tensor, std::make_shared<CppTensorVoidHook>(reduce_hook));
egr_utils_api::RegisterReduceHookForTensor(tensor, reduce_hook);
VLOG(6) << "Runing Forward";
auto output_tensor = sigmoid_dygraph_function(tensor, {});
......@@ -161,10 +160,9 @@ void test_elementwiseAdd(bool is_remove_gradient_hook) {
};
egr_utils_api::RetainGradForTensor(Y);
int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
Y, std::make_shared<CppTensorHook>(hook_function));
egr_utils_api::RegisterReduceHookForTensor(
Y, std::make_shared<CppTensorVoidHook>(reduce_hook));
int64_t hook_id =
egr_utils_api::RegisterGradientHookForTensor(Y, hook_function);
egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook);
auto output_tensor = elementwise_add_dygraph_function(X, Y, {});
......@@ -226,10 +224,9 @@ void test_matmul(bool is_remove_gradient_hook) {
};
egr_utils_api::RetainGradForTensor(Y);
int64_t hook_id = egr_utils_api::RegisterGradientHookForTensor(
Y, std::make_shared<CppTensorHook>(hook_function));
egr_utils_api::RegisterReduceHookForTensor(
Y, std::make_shared<CppTensorVoidHook>(reduce_hook));
int64_t hook_id =
egr_utils_api::RegisterGradientHookForTensor(Y, hook_function);
egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook);
auto output_tensor = matmul_v2_dygraph_function(
X, Y, {{"trans_x", false}, {"trans_y", false}});
......@@ -256,6 +253,59 @@ void test_matmul(bool is_remove_gradient_hook) {
}
}
void test_backward_final_hooks() {
// Prepare Device Contexts
VLOG(6) << "Init Env";
eager_test::InitEnv(paddle::platform::CPUPlace());
VLOG(6) << "Make paddle::experimental::Tensor";
paddle::framework::DDim ddimX = phi::make_ddim({4, 16});
paddle::experimental::Tensor X =
egr_utils_api::CreateTensorWithValue(ddimX,
paddle::platform::CPUPlace(),
phi::DataType::FLOAT32,
phi::DataLayout::NCHW,
3.0,
true);
paddle::framework::DDim ddimY = phi::make_ddim({16, 20});
egr_utils_api::RetainGradForTensor(X);
paddle::experimental::Tensor Y =
egr_utils_api::CreateTensorWithValue(ddimY,
paddle::platform::CPUPlace(),
phi::DataType::FLOAT32,
phi::DataLayout::NCHW,
2.0,
true);
VLOG(6) << "Make ReduceHook function";
auto backward_final_hook = [&](void) -> void {
auto* t_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(X.impl())->data<float>();
VLOG(6) << "Run Target Backward Hook";
for (int i = 0; i < X.numel(); i++) {
t_ptr[i] = 100.0; // set to 100.0
}
};
VLOG(6) << "Register Backward Final Hook";
egr_utils_api::RegisterBackwardFinalHook(backward_final_hook);
VLOG(6) << "Runing Forward";
auto output_tensor = matmul_v2_dygraph_function(
X, Y, {{"trans_x", false}, {"trans_y", false}});
auto res = sigmoid_dygraph_function(output_tensor, {});
VLOG(6) << "Finish Forward";
eager_test::CompareTensorWithValue<float>(X, 3.0);
std::vector<paddle::experimental::Tensor> target_tensors = {output_tensor};
VLOG(6) << "Runing Backward";
Backward(target_tensors, {});
VLOG(6) << "Finish Backward";
eager_test::CompareTensorWithValue<float>(X, 100.0);
}
TEST(Hook_intermidiate, Sigmoid) {
// True or false represents whether to call RemoveGradientHook
test_sigmoid(true);
......@@ -271,6 +321,8 @@ TEST(Hook_intermidiate, Matmul_v2) {
test_matmul(true);
test_matmul(false);
}
TEST(Hook_intermidiate, BackwardFinal) { test_backward_final_hooks(); }
} // namespace egr
USE_OP_ITSELF(sigmoid);
......
......@@ -907,12 +907,27 @@ static PyObject* eager_api_to_uva_tensor(PyObject* self,
}
#endif
static PyObject* eager_api__add_backward_final_hook(PyObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PyObject* hook_func = PyTuple_GET_ITEM(args, 0);
egr::Controller::Instance().RegisterBackwardFinalHook(
std::make_shared<PyVoidHook>(hook_func));
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef variable_functions[] = {
// TODO(jiabin): Remove scale when we have final state tests
{"scale",
(PyCFunction)(void (*)(void))eager_api_scale,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_add_backward_final_hook",
(PyCFunction)(void (*)(void))eager_api__add_backward_final_hook,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"run_backward",
(PyCFunction)(void (*)(void))eager_api_run_backward,
METH_VARARGS | METH_KEYWORDS,
......
......@@ -57,86 +57,6 @@ typedef SSIZE_T ssize_t;
namespace paddle {
namespace pybind {
namespace py = ::pybind11;
class PyTensorHook : public egr::TensorHook {
public:
explicit PyTensorHook(PyObject* func) : py_func_(func) {
Py_INCREF(py_func_);
}
~PyTensorHook() {
py::gil_scoped_acquire gil;
Py_DECREF(py_func_);
}
paddle::experimental::Tensor operator()(
const paddle::experimental::Tensor& var) override {
py::gil_scoped_acquire gil;
VLOG(3) << "Call PyTensorHook for var " << var.name();
PyObject* res = nullptr;
try {
PyObject* p_tmp_var = ToPyObject(var);
res = PyObject_CallFunctionObjArgs(py_func_, p_tmp_var, nullptr);
Py_DECREF(p_tmp_var);
} catch (platform::EnforceNotMet& e) {
throw std::move(e);
} catch (std::exception& e) {
PADDLE_THROW(platform::errors::Unavailable(
"Hook function of Tensor raises an exception: %s.", e.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Hook function of Tensor raises an unknown exception."));
}
PADDLE_ENFORCE_NOT_NULL(res,
platform::errors::Unavailable(
"Hook function of Tensor return a nullptr."));
if (res == Py_None) {
return var;
}
auto res_tensor = reinterpret_cast<TensorObject*>(res)->tensor;
Py_DECREF(res);
return res_tensor;
}
private:
PyObject* py_func_;
};
class PyTensorVoidHook : public egr::TensorVoidHook {
public:
explicit PyTensorVoidHook(PyObject* func) : py_func_(func) {
Py_INCREF(py_func_);
}
~PyTensorVoidHook() {
py::gil_scoped_acquire gil;
Py_DECREF(py_func_);
}
void operator()() override {
py::gil_scoped_acquire gil;
VLOG(3) << "Call PyTensorVoidHook";
try {
PyObject_CallFunctionObjArgs(py_func_, nullptr);
} catch (platform::EnforceNotMet& e) {
throw std::move(e);
} catch (std::exception& e) {
PADDLE_THROW(platform::errors::Unavailable(
"Hook function of Tensor raises an exception: %s.", e.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Hook function of Tensor raises an unknown exception."));
}
}
private:
PyObject* py_func_;
};
extern void InitTensorWithNumpyValue(TensorObject* self,
const pybind11::object& array,
const paddle::platform::Place& place,
......@@ -1363,7 +1283,7 @@ static PyObject* tensor_register_reduce_hook(TensorObject* self,
auto accumulation_grad_node =
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook(
std::make_shared<PyTensorVoidHook>(hook_func));
std::make_shared<PyVoidHook>(hook_func));
RETURN_PY_NONE
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope_guard.h"
......@@ -1427,5 +1428,54 @@ paddle::DataType CastPyArg2DataType(PyObject* obj,
framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
return framework::TransToPhiDataType(type);
}
paddle::experimental::Tensor PyTensorHook::operator()(
const paddle::experimental::Tensor& var) {
py::gil_scoped_acquire gil;
VLOG(3) << "Call PyTensorHook for var " << var.name();
PyObject* res = nullptr;
try {
PyObject* p_tmp_var = ToPyObject(var);
res = PyObject_CallFunctionObjArgs(py_func_, p_tmp_var, nullptr);
Py_DECREF(p_tmp_var);
} catch (platform::EnforceNotMet& e) {
throw std::move(e);
} catch (std::exception& e) {
PADDLE_THROW(platform::errors::Unavailable(
"Hook function of Tensor raises an exception: %s.", e.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Hook function of Tensor raises an unknown exception."));
}
PADDLE_ENFORCE_NOT_NULL(res,
platform::errors::Unavailable(
"Hook function of Tensor return a nullptr."));
if (res == Py_None) {
return var;
}
auto res_tensor = reinterpret_cast<TensorObject*>(res)->tensor;
Py_DECREF(res);
return res_tensor;
}
void PyVoidHook::operator()() {
py::gil_scoped_acquire gil;
VLOG(3) << "Call PyVoidHook";
try {
PyObject_CallFunctionObjArgs(py_func_, nullptr);
} catch (platform::EnforceNotMet& e) {
throw std::move(e);
} catch (std::exception& e) {
PADDLE_THROW(platform::errors::Unavailable(
"Hook function of Tensor raises an exception: %s.", e.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Hook function of Tensor raises an unknown exception."));
}
}
} // namespace pybind
} // namespace paddle
......@@ -17,6 +17,7 @@ typedef SSIZE_T ssize_t;
#include <Python.h>
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/jit/base_function.h"
......@@ -36,6 +37,7 @@ class Scope;
}
namespace pybind {
namespace py = ::pybind11;
#define RETURN_PY_NONE \
Py_INCREF(Py_None); \
return Py_None;
......@@ -110,6 +112,39 @@ PyObject* ToPyObject(
const std::unordered_map<std::string, std::vector<std::string>>& value);
PyObject* ToPyObject(const std::unordered_map<std::wstring, int>& value);
class PyTensorHook : public egr::TensorHook {
public:
explicit PyTensorHook(PyObject* func) : py_func_(func) {
Py_INCREF(py_func_);
}
~PyTensorHook() {
py::gil_scoped_acquire gil;
Py_DECREF(py_func_);
}
paddle::experimental::Tensor operator()(
const paddle::experimental::Tensor& var) override;
private:
PyObject* py_func_;
};
class PyVoidHook : public egr::VoidHook {
public:
explicit PyVoidHook(PyObject* func) : py_func_(func) { Py_INCREF(py_func_); }
~PyVoidHook() {
py::gil_scoped_acquire gil;
Py_DECREF(py_func_);
}
void operator()() override;
private:
PyObject* py_func_;
};
template <typename Tuple, size_t N>
struct TupleTensorResult {
static void Run(const Tuple& out, PyObject* result) {
......
......@@ -639,5 +639,35 @@ class TestTensorRegisterBackwardHook(unittest.TestCase):
self.func_register_backward_hook_for_var_without_gradient()
class TestRegsiterBackwardFinalHook(unittest.TestCase):
def setUp(self):
self.devices = ["cpu"]
if paddle.is_compiled_with_cuda():
self.devices.append("gpu")
def test_register_backward_hook(self):
global HOOK_INIT_VALUE
global HOOK_IS_CALLED
for device in self.devices:
np_x = np.random.rand(4, 16).astype("float32")
np_y = np.random.rand(16, 20).astype("float32")
x = paddle.to_tensor(np_x, stop_gradient=False)
y = paddle.to_tensor(np_y, stop_gradient=False)
core.eager._add_backward_final_hook(global_void_hook)
out = paddle.matmul(x, y)
out = paddle.sum(out)
out.backward()
self.assertEqual(HOOK_INIT_VALUE, 20)
self.assertTrue(HOOK_IS_CALLED)
# reset initial value
HOOK_INIT_VALUE = 10
HOOK_IS_CALLED = False
if __name__ == '__main__':
unittest.main()