From dbeb3ea422acaf888684c588066b12fbfce9d52c Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 31 Mar 2021 21:44:36 -0500 Subject: [PATCH] Refactor and simplify hook design & add Tensor.register_hook API (#31775) * refactor and simplify hook design * fix reducer add hook error * add Tensor.register_hook basic impl * refine prepare data impl * revert prepare data change * support register_hook for Tensor * add hook test in model * polish tests and doc example * fix double grad test failed * remove reduce hook func * fix set empty error * polish code by comments * change reduce_hook to mutable_hook * remove useless tmp_ins * fix shape code format error * fix shape code format error --- paddle/fluid/imperative/basic_engine.cc | 79 +++- .../fluid/imperative/gradient_accumulator.cc | 61 ++- .../fluid/imperative/gradient_accumulator.h | 70 ++- paddle/fluid/imperative/hooks.h | 196 ++------- paddle/fluid/imperative/layer.h | 21 + paddle/fluid/imperative/op_base.h | 2 - .../fluid/imperative/partial_grad_engine.cc | 4 + paddle/fluid/imperative/reducer.cc | 8 +- paddle/fluid/imperative/tests/test_hooks.cc | 20 +- paddle/fluid/imperative/variable_wrapper.h | 120 ++--- paddle/fluid/pybind/imperative.cc | 125 ++++-- .../fluid/dygraph/varbase_patch_methods.py | 100 ++++- .../unittests/test_tensor_register_hook.py | 413 ++++++++++++++++++ 13 files changed, 863 insertions(+), 356 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_register_hook.py diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index 29ba549868..9e46af9cb7 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -141,17 +141,6 @@ void BasicEngine::PrepareGradAccumulators( << var.get() << ") that don't have grad node with reference count " << accumulator->RefCnt(); - - if (var->HasLeafHooks()) { - VLOG(3) << "Grad variable wrapper (" << var->Name() - << ") has leaf grad hooks."; - PADDLE_ENFORCE_NE( - var->HasGradNode(), true, - platform::errors::PermissionDenied( - "Only leaf Tensor's gradient can append hook to " - "Gradientaccumulator.")); - accumulator->SetPostHooks(var->GetLeafHooks()); - } } else { // Because Inplace op overwrites the grad_node of the input grad_var. So // only the information of grad_pending_node can be used to find the @@ -262,6 +251,30 @@ void BasicEngine::PrepareDeps() { } } +static std::shared_ptr> CallGradientHooks( + const NameVarMap& bwd_ins, const std::string& op_type) { + std::shared_ptr> tmp_ins_ptr = nullptr; + for (const auto& pair : bwd_ins) { + for (size_t i = 0; i < pair.second.size(); ++i) { + auto& var = pair.second[i]; + if (var->HasHook()) { + if (tmp_ins_ptr == nullptr) { + tmp_ins_ptr = std::make_shared>(bwd_ins); + } + VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type + << "'s input `" << pair.first << "`'s var `" << var->Name() + << "`."; + auto tmp_var = var; + for (const auto& hook_pair : var->GetHooks()) { + tmp_var = (*hook_pair.second)(tmp_var); + } + (*tmp_ins_ptr)[pair.first][i] = tmp_var; + } + } + } + return tmp_ins_ptr; +} + void BasicEngine::Execute() { if (init_node_ == nullptr) { return; @@ -292,10 +305,15 @@ void BasicEngine::Execute() { auto& bwd_ins = cur_op.GetInsMap(); auto& bwd_outs = cur_op.GetOutsMap(); + /** + * [ Why need temporary outputs here? 
] + * + * - construct the temp output map, avoid to disrupt graph + * - replace the element in the map by temp var, because a + * var may be coresponding to several grad var in one op + */ NameVarMap tmp_outs(bwd_outs); - // 1. construct the temp output map, avoid to disrupt graph - // 2. replace the element in the map by temp var, because a - // var may be coresponding to several grad var in one op + for (auto& pair : tmp_outs) { if (!pair.second.IsGrad()) { continue; @@ -408,10 +426,28 @@ void BasicEngine::Execute() { } } + /** + * [ Why need temporary inputs here? ] + * + * - Hook execution should not change original input tensor. + * User can register hook for Tensor's gradient, It is expected + * that the hook only affects the gradient of the backward + * propagation, and does not affect the gradient value input + * as the hook. + * - use `tmp_ins_ptr`, only copy bwd_ins when the var in bwd_ins + * hold hooks + */ + auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type()); + { VLOG(3) << "Start to execute grad op " << cur_op.Type(); - OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), - cur_op.place()); + if (tmp_ins_ptr == nullptr) { + OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), + cur_op.place()); + } else { + OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, cur_op.Attrs(), + cur_op.place()); + } } for (auto& pair : inplace_output_grad_var_list_) { @@ -428,15 +464,14 @@ void BasicEngine::Execute() { if (!accumulator->SumGradCompleted()) { continue; } - // 1. Call Hooks for **inner_var_** + // 1. Call Hooks for `inner_var_` + accumulator->CallGradientHooks(); - // 2. Sum Gradient with Previous Graph + // 2. Sum Gradient `inner_var_` to `var_` of Current or Previous Graph accumulator->AccumulateGrad(); - // 3. Call backward Hooks for **var_** - if (accumulator->HasPostHooks()) { - accumulator->CallBackwardPostHooks(); - } + // 3. Call backward Hooks for `var_` + accumulator->CallReduceHooks(); } need_accu_var_list_.clear(); diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index b9df88b1f1..df5ff750c9 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -384,8 +384,8 @@ static platform::Place GetPlaceOfVar( void GradientAccumulator::AccumulateGrad() { /** - * If the gradient has been calculated by previous graph, - * it should be added to the previous graph result. + * If the leaf gradient has been calculated done, the inner_var_ + * should be added to the var_. 
*/ if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) { return; @@ -396,7 +396,7 @@ void GradientAccumulator::AccumulateGrad() { "this auto-grad")); PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true, platform::errors::InvalidArgument( - "Interior var of Leaf tensor should be initialized.")); + "Interior var of Leaf tensor should be initialized.")); auto* src = inner_var_->MutableVar(); auto* dst = var_->MutableVar(); if (!var_->IsEmpty()) { @@ -427,10 +427,65 @@ void GradientAccumulator::AccumulateGrad() { *(dst) = std::move(*src); var_->SetType(inner_var_->Type()); var_->SetDataType(inner_var_->DataType()); + var_->SetIsEmpty(false); } inner_var_.reset(); } +void GradientAccumulator::CallGradientHooks() { + PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true, + platform::errors::Unavailable( + "Only leaf gradient Tensor can deal with by gradient " + "hook in gradient accumulator.")); + PADDLE_ENFORCE_EQ( + SumGradCompleted(), true, + platform::errors::PreconditionNotMet( + "Only can call gradient hooks after sum gradient completed.")); + PADDLE_ENFORCE_EQ( + HasInnerVar(), true, + platform::errors::PreconditionNotMet( + "Leaf Tensor's inner var is nullptr when call gradient hook.")); + PADDLE_ENFORCE_EQ( + inner_var_->Var().IsInitialized(), true, + platform::errors::PreconditionNotMet("Leaf Tensor's inner var " + "is not initialized when " + "call gradient hook.")); + if (var_->HasHook()) { + VLOG(3) << "Call " << var_->GetHooks().size() + << " hooks of leaf gradient accumulator's inner var `" + << var_->Name() << "`."; + auto tmp_var = inner_var_; + VLOG(3) << "Input var " << var_->Name() << "'s hook size - " + << var_->GetHooks().size(); + for (const auto& hook_pair : var_->GetHooks()) { + tmp_var = (*hook_pair.second)(tmp_var); + } + inner_var_ = tmp_var; + } +} + +void GradientAccumulator::CallReduceHooks() { + PADDLE_ENFORCE_EQ( + var_->IsLeafGrad(), true, + platform::errors::Unavailable("Only leaf gradient Tensor can deal with " + "by reduce hook in gradient accumulator.")); + PADDLE_ENFORCE_EQ(SumGradCompleted(), true, + platform::errors::PreconditionNotMet( + "Only can call reduce hooks after the gradient " + "summation is completed in current batch.")); + PADDLE_ENFORCE_EQ(HasInnerVar(), false, + platform::errors::PreconditionNotMet( + "Only can call reduce hooks after the " + "gradient accumulation is completed in " + "current batch or across batchs.")); + if (var_->HasMutableHook()) { + for (const auto& hook : var_->GetMutableHooks()) { + VLOG(3) << "call gradient accumulator backward hooks."; + (*hook)(var_); + } + } +} + void EagerGradientAccumulator::SumGrad(std::shared_ptr var, size_t trace_id, bool unchange_input) { /** diff --git a/paddle/fluid/imperative/gradient_accumulator.h b/paddle/fluid/imperative/gradient_accumulator.h index e2dabc06a7..6411dce440 100644 --- a/paddle/fluid/imperative/gradient_accumulator.h +++ b/paddle/fluid/imperative/gradient_accumulator.h @@ -40,8 +40,8 @@ class GradientAccumulator { } // inner_var_ record the grad of this auto-grad. - // Only need to generate inner var for non-empty leaf-tensor. - if (var->IsLeafGrad() && !var->IsEmpty()) { + // Only need to generate inner var for leaf-tensor. 
+ if (var->IsLeafGrad()) { inner_var_ = std::make_shared(var->Name()); inner_var_->SetType(var->Type()); inner_var_->SetDataType(var->DataType()); @@ -52,9 +52,6 @@ class GradientAccumulator { << ") to store result of this Graph"; } - // TODO(zhouwei): fix Tensor.clear_gradient() bug, remove this hard flag - var->SetIsEmpty(false); - // var_ is the final grad, processed by hooks and grad accumulation var_ = var; } @@ -93,42 +90,38 @@ class GradientAccumulator { inline bool HasInnerVar() const { return inner_var_ != nullptr; } - /* Hook related methods */ - inline bool HasPostHooks() const { return !post_hooks_.expired(); } - - void SetPostHooks(const std::shared_ptr& hooks) { - PADDLE_ENFORCE_NOT_NULL( - hooks, platform::errors::InvalidArgument( - "The hook set to GradientAccumulator is nullptr.")); - - auto shared_hooks = post_hooks_.lock(); - if (shared_hooks != hooks) { - PADDLE_ENFORCE_EQ( - shared_hooks, nullptr, - platform::errors::PermissionDenied( - "Cannot set post hooks twice to GradientAccumulator.")); - post_hooks_ = hooks; - } - } - // void CallHooks(){} - // ** inner_var_ ** - // function that Sum Gradient with Previous Graph void AccumulateGrad(); - // call backward post hooks, such as reduce hook - void CallBackwardPostHooks() { - PADDLE_ENFORCE_NE( - post_hooks_.expired(), true, - platform::errors::NotFound( - "The post hooks of GradientAccumulator for Tensor `%s` expired.", - var_->Name())); - auto shared_hooks = post_hooks_.lock(); - for (const auto& hook : shared_hooks->backward_hooks()) { - VLOG(3) << "call gradient accumulator backward hooks."; - (*hook)(var_); - } - } + /** [ Hook related methods ] + * + * [Why need two types of VariableWrapperHook? ] + * + * There are two types of gradient accumulation: + * 1. Gradient accumulation in same batch + * 2. Gradient accumulation across batchs + * The order of execution between Hooks and gradient accumulation: + + * [ Gradient accumulation in same batch] + * | + * [ leaf GradVarBase hooks ] + * | + * [ Gradient accumulation across batchs ] + * | + * [ Gradient reduce / allreduce hooks ] + + * Because we currently intend to accumulate these two gradient + * accumulation in one GradientAccumulator, We must distinguish between + * two types of hooks. + + * And the InplaceVariableWrapperHook does not allow users to register + * directly, and is currently only used to support the reduce strategy of + * parallel multi-card training. 
+ */ + + void CallGradientHooks(); + + void CallReduceHooks(); protected: VariableWrapper* var_; @@ -137,7 +130,6 @@ class GradientAccumulator { std::shared_ptr inner_var_; size_t ref_cnt_{0}; size_t cur_cnt_{0}; - std::weak_ptr post_hooks_; }; class EagerGradientAccumulator : public GradientAccumulator { diff --git a/paddle/fluid/imperative/hooks.h b/paddle/fluid/imperative/hooks.h index 1211ec6ae6..4d59298aed 100644 --- a/paddle/fluid/imperative/hooks.h +++ b/paddle/fluid/imperative/hooks.h @@ -18,100 +18,67 @@ #include #include #include - -#include "paddle/fluid/imperative/type_defs.h" -#include "paddle/fluid/platform/macros.h" - namespace paddle { namespace imperative { class VariableWrapper; -/** [ Basic hook classes ] - * s - * @brief OpBasePreHook is executed before the grad OpBase is executed, +/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ] + * + * @brief This hook functor is executed before the grad OpBase is executed, * taking the input of the current grad OpBase as input, and * executing python hooks (user-defined) or C++ hooks (developer-defined) * to achieve the purpose of custom operations on the interior VarBase * gradient. * - * @note OpBasePreHook will not change the input gradient VarBase. + * @note This hook functor will not change the input gradient VarBase. * * @note [Why need to be OpBase `PreHook`, why not `PostHook`?] * - * If set OpBase post hook, when the op executed end, the op's output - * gradient may not be the final state, because it may need other op's - * gradient output to accumulated to it. But before op can be executed, - * the gradient output must have been accumulated to final value. + * 1. We expect If set OpBase post hook, when the op executed end, the + * op's output gradient may not be the final state, because it may need + * other op's gradient output to accumulated to it. But before op can + * be executed, the gradient output must have been accumulated to final + * value. + * 2. We don’t want the hook to change its input Tensor value, so now + * we can't call all hooks in GradAccumulator. * * @note [Why only can be used for interior VarBase?] * * Because the leaf VarBase's GradVarBase has no GradOpNode, so leaf * GradVarBase has no next OpBase to executed, so if need to deal with - * the leaf GradVarBase, cannot use OpBasePreHook. For this case, we - * deal with by GradAccumulatorPostHook. + * the leaf GradVarBase, cannot use this hook functor. For this case, we + * deal with by other inplace hook method. */ -class OpBasePreHook { +class VariableWrapperHook { public: - virtual ~OpBasePreHook() = default; - virtual VariableWrapperList operator()( - const VariableWrapperList& grad_inputs) = 0; + virtual ~VariableWrapperHook() = default; + virtual std::shared_ptr operator()( + const std::shared_ptr& var) = 0; }; -/** - * @brief GradAccumulatorPostHook is the Hook that operates on the current +/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ] + * + * @brief This hook functor is the Hook that operates on the current * gradientafter the GradientAccumulator has accumulated the gradient. * Leaf GradVarBase has no next OpBase, if we want to register hook * for it, we also need to wait until the leaf GradVarBase accumulation * is completed, so we can add post hook to GradientAccumulator. * - * @note GradAccumulatorPostHook will change the grad VarBase value. + * @note This hook functor will change the grad VarBase value. * - * @note Only allow leaf VarBase hold GradientAccumulatorPostHook. 
+ * @note Only allow leaf VarBase hold call this hook functor. */ -class GradAccumulatorPostHook { +class InplaceVariableWrapperHook { public: - virtual ~GradAccumulatorPostHook() = default; + virtual ~InplaceVariableWrapperHook() = default; virtual void operator()(VariableWrapper* var) = 0; }; -/** [ Hook for cpp functions ] - * - * Here we design three C++ hooks; - * 1. CppOpBasePreHook (Implement later): - * - used for developer-defined C++ interior VarBase hooks - * 2. CppGradAccumulatorPostHook (Implement later): - * - used for developer-defined C++ leaf VarBase hooks - * 3. LambdaGradAccumulatorPostHook: - * - used for VarBase reduce in parallel training - * - * @note [Why need two types of GradAccumulatorPostHook? ] - * - * There are two types of gradient accumulation: - * 1. Gradient accumulation in same batch - * 2. Gradient accumulation across batchs - * The order of execution between Hooks and gradient accumulation: - * - * [ Gradient accumulation in same batch] - * | - * [ leaf GradVarBase hooks ] - * | - * [ Gradient accumulation across batchs ] - * | - * [ Gradient reduce / allreduce] - * - * Because we currently intend to accumulate these two gradient - * accumulation in one GradientAccumulator, We must distinguish between - * two types of hooks. - * - * And the LambdaGradAccumulatorPostHook does not allow users to register - * directly, and is currently only used to support the reduce strategy of - * parallel multi-card training. - */ -class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook { +class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook { public: - explicit LambdaGradAccumulatorPostHook( - std::function fn) + explicit LambdaInplaceVariableWrapperHook( + std::function&& fn) : fn_(std::move(fn)) {} void operator()(VariableWrapper* var) override { fn_(var); } @@ -120,114 +87,5 @@ class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook { std::function fn_; }; -/* Hooks for python function: in pybind/imperative.cc */ - -/** Add Python Hooks later: - * - PyOpBasePreHook (Implement later): used for user-defined interior python - * VarBase hooks - * - PyGradAccumulatorPostHook (Implement later): used for user-defined leaf - * python VarBase hooks - */ - -/** [ Hook Pipeline classes ] - * - * @note [Why need hook pipeline classes?] - * - * There are 2 purposes for adding Hook pipeline here: - * - * 1. Make the code implementation cleaner. - * - * If there are no Hook pipeline, we need to add 3 hook vector into - * VariableWrapper, 1 hook vector into OpBase, 2 hook vector into - * GradientAccumulator, like: - * - * - VariableWrapper: - * std::vector> - * interior_var_hooks_; - * std::vector> - * leaf_var_hooks_; - * std::vector> - * backward_hooks_; - * - * - OpBase: - * std::vector> - * interior_var_hooks_; - * - * - GradientAccumulator: - * std::vector> - * leaf_var_hooks_; - * std::vector> - * backward_hooks_; - * - * This seems more complicated, and std::vector> - * is not easy to destruct. - * - * 2. Make the code easier to understand. - * - * From these two packages, we can clearly understand that we - * have two types of Hooks, respectively for the interior - * gradient var and leaf gradient var inside the backward - * calculation graph. 
- */ - -class InteriorVarHookPipeline { - public: - InteriorVarHookPipeline() = default; - - void add_hook(std::unique_ptr&& hook) { - hooks_.emplace_back(std::move(hook)); - } - - const std::vector>& hooks() const { - return hooks_; - } - - std::vector>& hooks() { return hooks_; } - - private: - std::vector> hooks_; - - DISABLE_COPY_AND_ASSIGN(InteriorVarHookPipeline); -}; - -class LeafVarHookPipeline { - public: - LeafVarHookPipeline() = default; - - void add_hook(std::unique_ptr&& hook) { - hooks_.emplace_back(std::move(hook)); - } - - const std::vector>& hooks() const { - return hooks_; - } - - std::vector>& hooks() { - return hooks_; - } - - void add_backward_hook(std::unique_ptr&& hook) { - backward_hooks_.emplace_back(std::move(hook)); - } - - const std::vector>& backward_hooks() - const { - return backward_hooks_; - } - - std::vector>& backward_hooks() { - return backward_hooks_; - } - - private: - std::vector> hooks_; - // NOTE: the `backward` here means the `whole backward process`, - // the `backward_hooks_` need to be executed after the `whole backward - // process`. - std::vector> backward_hooks_; - - DISABLE_COPY_AND_ASSIGN(LeafVarHookPipeline); -}; - } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index ff5a780a5f..f87db41576 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -30,6 +30,7 @@ #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/flags.h" +#include "paddle/fluid/imperative/hooks.h" #include "paddle/fluid/imperative/saved_variable_wrapper_list.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/variable_wrapper.h" @@ -220,6 +221,26 @@ class VarBase { void BumpInplaceVersion(); + /* Hook related method: now only used for GradVarBase */ + bool HasHook() const { return var_->HasHook(); } + + int64_t AddHook(std::shared_ptr&& hook) { + return var_->AddHook( + std::forward>(hook)); + } + + bool RemoveHook(const int64_t& hook_id) { return var_->RemoveHook(hook_id); } + + const std::map>& GetHooks() + const { + return var_->GetHooks(); + } + + void AddMutableHook(std::shared_ptr&& hook) { + var_->AddMutableHook( + std::forward>(hook)); + } + private: /** * NOTE(zengjinle): never remove the const qualifier of `var_` if you are diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 2b7642ae7c..0164ff9313 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -177,8 +177,6 @@ class OpBase { std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; - - std::weak_ptr pre_hooks_; }; class GradOpNode { diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 8dd8cafc83..3da3a05ed1 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -369,6 +369,10 @@ class GradientAccumulationInfo { *is_finished = (cur_ref_cnt_ == total_ref_cnt_); accumulator_->SumGrad(grad_var_partial, trace_id, unchange_input); + if (*is_finished && accumulator_->HasInnerVar()) { + accumulator_->AccumulateGrad(); + } + if (create_graph_) { VLOG(10) << "Store partial grad grad for double grad " << mapped_grad_var_->Name(); diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index e8b531d35c..4b18886821 100644 --- a/paddle/fluid/imperative/reducer.cc +++ 
b/paddle/fluid/imperative/reducer.cc @@ -310,11 +310,9 @@ Reducer::Reducer(const std::vector> &vars, for (size_t global_var_index = 0; global_var_index < vars_.size(); ++global_var_index) { auto var = vars_[global_var_index]; - var->SharedVar()->AddGradVarLeafBackwardHook( - std::unique_ptr( - new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) { - this->AddDistHook(global_var_index); - }))); + var->GradVarBase()->AddMutableHook( + std::make_shared([=]( + VariableWrapper *grad) { this->AddDistHook(global_var_index); })); var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index; } } diff --git a/paddle/fluid/imperative/tests/test_hooks.cc b/paddle/fluid/imperative/tests/test_hooks.cc index 7bf5f87668..9b75fac0ca 100644 --- a/paddle/fluid/imperative/tests/test_hooks.cc +++ b/paddle/fluid/imperative/tests/test_hooks.cc @@ -74,16 +74,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) { mul_attr_map["use_mkldnn"] = false; // add GradAccumulatorPostHook - auto x_var_wrapper = x->SharedVar(); - x_var_wrapper->AddGradVarLeafBackwardHook( - std::unique_ptr( - new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) { + x->GradVarBase()->AddMutableHook( + std::make_shared( + [=](VariableWrapper* grad) { auto* grad_tensor = grad->MutableVar()->GetMutable(); for (int i = 0; i < grad_tensor->numel(); ++i) { grad_tensor->mutable_data(place)[i] *= 2.0; } - }))); + })); // 2. forward tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); @@ -151,17 +150,16 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() { memory::Copy(place, mutable_z, place, src_data.data(), sizeof(float) * src_data.size()); - // add GradAccumulatorPostHook - auto x_var_wrapper = x->SharedVar(); - x_var_wrapper->AddGradVarLeafBackwardHook( - std::unique_ptr( - new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) { + // add ReduceBackwardHook + x->GradVarBase()->AddMutableHook( + std::make_shared( + [=](VariableWrapper* grad) { auto* grad_tensor = grad->MutableVar()->GetMutable(); for (int i = 0; i < grad_tensor->numel(); ++i) { grad_tensor->mutable_data(place)[i] *= 2.0; } - }))); + })); // 2. 
forward var_pair x_pair = var_pair("X", vb_vector(1, x)); diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index b42f25dcc8..7d287c9829 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -27,8 +27,8 @@ namespace paddle { namespace imperative { -class InteriorVarHookPipeline; -class LeafVarHookPipeline; +class VariableWrapperHook; +class InplaceVariableWrapperHook; class VarBase; class GradOpNode; @@ -193,42 +193,6 @@ class VariableWrapper { } } - /* Hook related method: only can be call by GradVarBase */ - - bool HasInteriorHooks() const { return interior_hooks_ != nullptr; } - - bool HasLeafHooks() const { return leaf_hooks_ != nullptr; } - - void AddGradVarInteriorHook(std::unique_ptr&& hook) { - auto interior_hooks = GetGradVarInteriorHooksSafely(); - interior_hooks->add_hook(std::move(hook)); - } - - void AddGradVarLeafHook(std::unique_ptr&& hook) { - auto leaf_hooks = GetGradVarLeafHooksSafely(); - leaf_hooks->add_hook(std::move(hook)); - } - - void AddGradVarLeafBackwardHook( - std::unique_ptr&& hook) { - auto leaf_hooks = GetGradVarLeafHooksSafely(); - leaf_hooks->add_backward_hook(std::move(hook)); - } - - const std::shared_ptr& GetInteriorHooks() const { - return interior_hooks_; - } - - std::shared_ptr& GetInteriorHooks() { - return interior_hooks_; - } - - const std::shared_ptr& GetLeafHooks() const { - return leaf_hooks_; - } - - std::shared_ptr& GetLeafHooks() { return leaf_hooks_; } - uint32_t InplaceVersionSnapshot() const { return inplace_version_snapshot_; } void ResetInplaceVersion() { @@ -255,6 +219,38 @@ class VariableWrapper { return; } + /* Hook related methods */ + bool HasHook() const { return !hooks_.empty(); } + + bool HasMutableHook() const { return !mutable_hooks_.empty(); } + + int64_t AddHook(std::shared_ptr&& hook) { + hooks_.emplace(next_hook_id_, std::move(hook)); + return next_hook_id_++; + } + + bool RemoveHook(const int64_t& hook_id) { + auto remove_cnt = hooks_.erase(hook_id); + if (remove_cnt == 0) { + return false; + } + return true; + } + + const std::map>& GetHooks() + const { + return hooks_; + } + + void AddMutableHook(std::shared_ptr&& hook) { + mutable_hooks_.emplace_back(std::move(hook)); + } + + const std::vector>& + GetMutableHooks() const { + return mutable_hooks_; + } + private: void SetGradVar(const std::shared_ptr& var) { auto shared_var = grad_var_.lock(); @@ -289,41 +285,6 @@ class VariableWrapper { } } - /* Hook related private methods */ - std::shared_ptr GetGradVarSafely() const { - auto shared_grad_var = grad_var_.lock(); - PADDLE_ENFORCE_NOT_NULL( - shared_grad_var, - platform::errors::PermissionDenied( - "Cannot add gradient hook on Tensor without gradient.")); - return shared_grad_var; - } - - std::shared_ptr& GetGradVarInteriorHooksSafely() { - auto shared_grad_var = GetGradVarSafely(); - PADDLE_ENFORCE_EQ(HasGradNode(), true, - platform::errors::PermissionDenied( - "Only interior Tensor in backward can register " - "interior gradient hook.")); - if (shared_grad_var->interior_hooks_ == nullptr) { - shared_grad_var->interior_hooks_ = - std::make_shared(); - } - return shared_grad_var->interior_hooks_; - } - - std::shared_ptr& GetGradVarLeafHooksSafely() { - auto shared_grad_var = GetGradVarSafely(); - PADDLE_ENFORCE_EQ( - HasGradNode(), false, - platform::errors::PermissionDenied( - "Only leaf Tensor in backward can register leaf gradient hook.")); - if (shared_grad_var->leaf_hooks_ == nullptr) { - 
shared_grad_var->leaf_hooks_ = std::make_shared(); - } - return shared_grad_var->leaf_hooks_; - } - private: framework::Variable var_; std::string name_; @@ -358,11 +319,14 @@ class VariableWrapper { // isn't need bool is_empty_{false}; - // NOTE: only grad var can hold hooks now - // only interior var can hold interior hooks - std::shared_ptr interior_hooks_; - // only leaf var can hold leaf hooks - std::shared_ptr leaf_hooks_; + // NOTE(chenweihang): only grad var can hold hooks now + int64_t next_hook_id_{0}; + // Hooks used to register hook for grad var, support adding and removing, + // key is the accumulated int64_t value + std::map> hooks_; + // Hooks executed after the execution of the entire backward process is over, + // currently only supported for reducing in distributed training + std::vector> mutable_hooks_; }; } // namespace imperative diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 40cf6cd84b..38ba1dc029 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/bkcl_context.h" #include "paddle/fluid/imperative/data_loader.h" +#include "paddle/fluid/imperative/hooks.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/nccl_context.h" #include "paddle/fluid/imperative/partial_grad_engine.h" @@ -63,6 +64,65 @@ class Layer : public imperative::Layer { } }; +template +static T PyObjectCast(PyObject *obj) { + try { + return py::cast(py::handle(obj)); + } catch (py::cast_error &) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Python object is not type of %s", typeid(T).name())); + } +} + +class PyVariableWrapperHook : public imperative::VariableWrapperHook { + public: + explicit PyVariableWrapperHook(PyObject *func) : py_func_(func) { + Py_INCREF(py_func_); + } + + ~PyVariableWrapperHook() { + py::gil_scoped_acquire gil; + Py_DECREF(py_func_); + } + + std::shared_ptr operator()( + const std::shared_ptr &var) override { + py::gil_scoped_acquire gil; + VLOG(3) << "Call PyVariableWrapperHook for var " << var->Name(); + + // 1. unpack temp VarBase from VariableWrapper + std::shared_ptr tmp_varbase = + std::make_shared(var); + + // 2. call hook and return + PyObject *res = nullptr; + try { + res = PyObject_CallFunctionObjArgs(py_func_, py::cast(tmp_varbase).ptr(), + nullptr); + } catch (platform::EnforceNotMet &e) { + throw std::move(e); + } catch (std::exception &e) { + PADDLE_THROW(platform::errors::Unavailable( + "Hook function of Tensor raises an exception: %s.", e.what())); + } catch (...) 
{ + PADDLE_THROW(platform::errors::Fatal( + "Hook function of Tensor raises an unknown exception.")); + } + + PADDLE_ENFORCE_NOT_NULL(res, + platform::errors::Unavailable( + "Hook function of Tensor return a nullptr.")); + if (res == Py_None) { + return var; + } + + return PyObjectCast>(res)->SharedVar(); + } + + private: + PyObject *py_func_; +}; + static const platform::Place PyObjectToPlace(const py::object &place_obj) { if (py::isinstance(place_obj)) { return place_obj.cast(); @@ -213,16 +273,6 @@ static std::string GetTypeName(const imperative::VarBase &var) { using PyNameVarBaseMap = std::unordered_map; -template -static T PyObjectCast(PyObject *obj) { - try { - return py::cast(py::handle(obj)); - } catch (py::cast_error &) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Python object is not type of %s", typeid(T).name())); - } -} - // NOTE(zjl): py::handle is a very light wrapper of PyObject *. // Unlike py::object, py::handle does not change reference count of PyObject *. static std::vector> @@ -1023,6 +1073,23 @@ void BindImperative(py::module *m_ptr) { } }, py::call_guard()) + .def("_register_grad_hook", + [](imperative::VarBase &self, const py::handle &hook) { + PADDLE_ENFORCE_EQ( + self.HasGradVar(), true, + platform::errors::InvalidArgument( + "Cannot register hook on a tensor without gradient.")); + return self.GradVarBase()->AddHook( + std::make_shared(hook.ptr())); + }) + .def("_remove_grad_hook", + [](imperative::VarBase &self, int64_t hook_id) { + PADDLE_ENFORCE_EQ( + self.HasGradVar(), true, + platform::errors::InvalidArgument( + "Cannot remove hook on a tensor without gradient.")); + return self.GradVarBase()->RemoveHook(hook_id); + }) .def("cpu", [](const std::shared_ptr &self) { if (platform::is_cpu_place(self->Place())) { @@ -1231,22 +1298,28 @@ void BindImperative(py::module *m_ptr) { &imperative::VarBase::SetOverridedStopGradient) .def_property("persistable", &imperative::VarBase::Persistable, &imperative::VarBase::SetPersistable) - .def_property_readonly( - "shape", - [](imperative::VarBase &self) { - if (self.Var().IsType()) { - return framework::vectorize( - self.Var().Get().dims()); - } else if (self.Var().IsType()) { - return framework::vectorize( - self.Var().Get().value().dims()); - } else { - VLOG(2) << "It is meaningless to get shape of " - "variable type " - << GetTypeName(self); - return std::vector(); - } - }) + .def_property_readonly("shape", + [](imperative::VarBase &self) { + if (self.Var().IsType()) { + return framework::vectorize( + self.Var() + .Get() + .dims()); + } else if (self.Var() + .IsType< + framework::SelectedRows>()) { + return framework::vectorize( + self.Var() + .Get() + .value() + .dims()); + } else { + VLOG(2) << "It is meaningless to get shape of " + "variable type " + << GetTypeName(self); + return std::vector(); + } + }) .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, R"DOC( Whether a Tensor is leaf Tensor. diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index ac0944c571..e565552632 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -14,6 +14,8 @@ import inspect import numpy as np +import warnings +import weakref import paddle from .. 
import framework @@ -26,6 +28,34 @@ from .parallel import scale_loss from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE +class TensorHookRemoveHelper(object): + """ + A helper class that for removing Tensor gradient's hook. + """ + + def __init__(self, tensor, hook_id): + self._tensor_ref = weakref.ref(tensor) + self._hook_id = hook_id + + def remove(self): + """ + Remove reference Tensor's hook. + + Returns: + bool: Return True if removed successfully + """ + tensor = self._tensor_ref() + if tensor is not None: + res = tensor._remove_grad_hook(self._hook_id) + if res is True: + return True + else: + warnings.warn( + "The backward hook (ID: %d) of Tensor `%s` you want to remove does not exist or has been removed." + % (self._hook_id, tensor.name), RuntimeWarning) + return False + + def monkey_patch_varbase(): @switch_to_static_graph def _to_static_var(self, to_parameter=False, **kwargs): @@ -211,6 +241,73 @@ def monkey_patch_varbase(): else: return np.array(new_ivar.value().get_tensor()) + @framework.dygraph_only + def register_hook(self, hook): + """ + Registers a backward hook for current Tensor. + + The hook will be called every time the gradient Tensor of current Tensor is computed. + + The hook should not modify the input gradient Tensor, but it can optionally return + a new gradient Tensor which will be used in place of current Tensor's gradient. + + The hook should have the following signature: + + hook(grad) -> Tensor or None + + Args: + hook(function): A backward hook to be registered for Tensor.grad + + Returns: + TensorHookRemoveHelper: A helper object that can be used to remove the registered hook by calling `remove()` method. + + Examples: + .. code-block:: python + + import paddle + + # hook function return None + def print_hook_fn(grad): + print(grad) + + # hook function return Tensor + def double_hook_fn(grad): + grad = grad * 2 + return grad + + x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False) + y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False) + z = paddle.to_tensor([1., 2., 3., 4.]) + + # one Tensor can register multiple hooks + h = x.register_hook(print_hook_fn) + x.register_hook(double_hook_fn) + + w = x + y + # register hook by lambda function + w.register_hook(lambda grad: grad * 2) + + o = z.matmul(w) + o.backward() + # print_hook_fn print content in backward + # Tensor(shape=[4], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [2., 4., 6., 8.]) + + print("w.grad:", w.grad) # w.grad: [1. 2. 3. 4.] + print("x.grad:", x.grad) # x.grad: [ 4. 8. 12. 16.] + print("y.grad:", y.grad) # y.grad: [2. 4. 6. 8.] 
+ + # remove hook + h.remove() + """ + if self.stop_gradient is True: + raise RuntimeError( + "Cannot register hook on a tensor that stop gradient.") + + hook_id = self._register_grad_hook(hook) + helper = TensorHookRemoveHelper(self, hook_id) + return helper + @property def grad(self): """ @@ -316,7 +413,8 @@ def monkey_patch_varbase(): ("_to_static_var", _to_static_var), ("set_value", set_value), ("block", block), ("backward", backward), ("clear_grad", clear_grad), ("inplace_version", inplace_version), ("grad", grad), - ("gradient", gradient), ("__str__", __str__), ("__repr__", __str__), + ("gradient", gradient), ("register_hook", register_hook), + ("__str__", __str__), ("__repr__", __str__), ("__deepcopy__", __deepcopy__), ("__module__", "paddle"), ("__name__", "Tensor")): setattr(core.VarBase, method_name, method) diff --git a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py new file mode 100644 index 0000000000..a390dd9d80 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py @@ -0,0 +1,413 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +import paddle.nn as nn + + +class SimpleNet(nn.Layer): + def __init__(self, in_size, out_size): + super(SimpleNet, self).__init__() + self.linear1 = nn.Linear(in_size, in_size) + self.linear2 = nn.Linear(in_size, out_size) + + def forward(self, x, hook=None, register=False, remove=False): + ret1 = self.linear1(x) + if hook is not None: + if register: + h = ret1.register_hook(hook) + if remove: + h.remove() + ret2 = self.linear2(ret1) + out = paddle.mean(ret2, axis=-1) + return ret1, out + + +class TestTensorRegisterHook(unittest.TestCase): + def setUp(self): + self.seed = 2021 + self.in_size = 10 + self.out_size = 10 + self.batch_size = 4 + self.devices = ["cpu"] + if paddle.is_compiled_with_cuda(): + self.devices.append("gpu") + + def test_hook_for_interior_var(self): + def run_double_hook_for_interior_var(double_hook, removed=False): + for device in self.devices: + paddle.set_device(device) + + x = paddle.to_tensor([0., 1., 2., 3.]) + y = paddle.to_tensor([4., 5., 6., 7.]) + x.stop_gradient = False + y.stop_gradient = False + + w = x + y + w.stop_gradient = False + helper = w.register_hook(double_hook) + + z = paddle.to_tensor([1., 2., 3., 4.]) + z.stop_gradient = False + + o = z.matmul(w) + + # remove hook before backward + if removed: + helper.remove() + + o.backward() + + # z.grad is not affected + self.assertTrue(np.array_equal(z.grad, w.numpy())) + # w.grad is not changed by hook + self.assertTrue(np.array_equal(w.grad, z.numpy())) + # x.grad and y.grad are changed if run hook + self.assertTrue( + np.array_equal(x.grad, + z.numpy() * 2 if not removed else z.numpy())) + self.assertTrue( + np.array_equal(y.grad, + z.numpy() * 2 if not removed else z.numpy())) + + def 
run_print_hook_for_interior_var(print_hook, removed=False): + for device in self.devices: + paddle.set_device(device) + + x = paddle.to_tensor([0., 1., 2., 3.]) + y = paddle.to_tensor([4., 5., 6., 7.]) + x.stop_gradient = False + y.stop_gradient = False + + w = x + y + w.stop_gradient = False + helper = w.register_hook(print_hook) + + z = paddle.to_tensor([1., 2., 3., 4.]) + z.stop_gradient = False + + o = z.matmul(w) + + # remove hook before backward + if removed: + helper.remove() + + o.backward() + + # all grads are not affected + self.assertTrue(np.array_equal(z.grad, w.numpy())) + self.assertTrue(np.array_equal(w.grad, z.numpy())) + self.assertTrue(np.array_equal(x.grad, z.numpy())) + self.assertTrue(np.array_equal(y.grad, z.numpy())) + + def double_hook(grad): + grad = grad * 2 + print(grad) + return grad + + def print_hook(grad): + print(grad) + + # register hook + run_double_hook_for_interior_var(double_hook) + # register hook and removed + run_double_hook_for_interior_var(double_hook, removed=True) + + # register hook + run_double_hook_for_interior_var(lambda grad: grad * 2) + # register hook and removed + run_double_hook_for_interior_var(lambda grad: grad * 2, removed=True) + + # register hook + run_print_hook_for_interior_var(print_hook) + # register hook and removed + run_print_hook_for_interior_var(print_hook, removed=True) + + def test_hook_for_leaf_var(self): + def run_double_hook_for_leaf_var(double_hook, removed=False): + for device in self.devices: + paddle.set_device(device) + + x = paddle.to_tensor([0., 1., 2., 3.]) + y = paddle.to_tensor([4., 5., 6., 7.]) + x.stop_gradient = False + y.stop_gradient = False + helper = y.register_hook(double_hook) + + w = x + y + w.stop_gradient = False + + z = paddle.to_tensor([1., 2., 3., 4.]) + z.stop_gradient = False + + o = z.matmul(w) + + # remove hook before backward + if removed: + helper.remove() + + o.backward() + + # z.grad, w.grad, x.grad is not affected + self.assertTrue(np.array_equal(z.grad, w.numpy())) + self.assertTrue(np.array_equal(w.grad, z.numpy())) + self.assertTrue(np.array_equal(x.grad, z.numpy())) + # y.grad are changed if run hook + self.assertTrue( + np.array_equal(y.grad, + z.numpy() * 2 if not removed else z.numpy())) + + # register hook + run_double_hook_for_leaf_var(lambda grad: grad * 2) + # register hook and removed + run_double_hook_for_leaf_var(lambda grad: grad * 2, removed=True) + + def test_hook_for_accumulated_grad(self): + def run_double_hook_for_accumulated_grad(double_hook, removed=False): + for device in self.devices: + paddle.set_device(device) + + a = paddle.to_tensor([0., 1., 1., 2.]) + b = paddle.to_tensor([0., 0., 1., 2.]) + a.stop_gradient = False + b.stop_gradient = False + + helper1 = a.register_hook(double_hook) + + x = a + b + x.stop_gradient = False + + helper2 = x.register_hook(double_hook) + + y = paddle.to_tensor([4., 5., 6., 7.]) + z = paddle.to_tensor([1., 2., 3., 4.]) + y.stop_gradient = False + z.stop_gradient = False + + o1 = x + y + o2 = x + z + o1.stop_gradient = False + o2.stop_gradient = False + + o = o1.matmul(o2) + + # remove hook before backward + if removed: + helper1.remove() + helper2.remove() + + o.backward() + + base_grad = np.array([5., 9., 13., 19.]) + # x.grad is not changed + self.assertTrue(np.array_equal(x.grad, base_grad)) + # b.grad is changed by x.hook + self.assertTrue( + np.array_equal(b.grad, base_grad * 2 + if not removed else base_grad)) + # a.grad is changed by x.hook and a.hook + self.assertTrue( + np.array_equal(a.grad, base_grad * 4 + if not 
removed else base_grad)) + + # register hook + run_double_hook_for_accumulated_grad(lambda grad: grad * 2) + # register hook and removed + run_double_hook_for_accumulated_grad( + lambda grad: grad * 2, removed=True) + + def test_hook_in_model(self): + def run_double_hook_in_model(data, + label, + hook=None, + register=False, + remove=False): + for device in self.devices: + paddle.seed(self.seed) + paddle.set_device(device) + + net = SimpleNet(self.in_size, self.out_size) + loss_fn = nn.MSELoss() + + data = paddle.to_tensor(data) + label = paddle.to_tensor(label) + + ret1, out = net(data, hook, register, remove) + loss = loss_fn(out, label) + loss.backward() + + return ret1.grad, net.linear1.weight.grad, net.linear1.bias.grad + + data = np.random.uniform( + size=[self.batch_size, self.in_size]).astype('float32') + label = np.random.uniform(size=[self.batch_size, 1]).astype('float32') + + # get original value + ret1_grad, linear1_w_grad, linear1_b_grad = run_double_hook_in_model( + data, label) + # get value changed by hook + ret1_grad_hook, linear1_w_grad_hook, linear1_b_grad_hook = run_double_hook_in_model( + data, label, lambda grad: grad * 2, True) + # get value after removing hook + ret1_grad_rm, linear1_w_grad_rm, linear1_b_grad_rm = run_double_hook_in_model( + data, label, lambda grad: grad * 2, True, True) + + # compare original value and with hook + self.assertTrue(np.array_equal(ret1_grad, ret1_grad_hook)) + self.assertTrue(np.array_equal(linear1_w_grad * 2, linear1_w_grad_hook)) + self.assertTrue(np.array_equal(linear1_b_grad * 2, linear1_b_grad_hook)) + + # compare original value and remove hook + self.assertTrue(np.array_equal(ret1_grad, ret1_grad_rm)) + self.assertTrue(np.array_equal(linear1_w_grad, linear1_w_grad_rm)) + self.assertTrue(np.array_equal(linear1_b_grad, linear1_b_grad_rm)) + + def test_multiple_hooks_for_interior_var(self): + def run_multiple_hooks_for_interior_var(device, + hooks, + remove1=False, + remove2=False, + remove3=False): + paddle.set_device(device) + + x = paddle.to_tensor([0., 1., 2., 3.]) + y = paddle.to_tensor([4., 5., 6., 7.]) + x.stop_gradient = False + y.stop_gradient = False + + w = x + y + w.stop_gradient = False + + helpers = [] + for hook in hooks: + helper = w.register_hook(hook) + helpers.append(helper) + + z = paddle.to_tensor([1., 2., 3., 4.]) + z.stop_gradient = False + + o = z.matmul(w) + + if remove1: + helpers[0].remove() + if remove2: + helpers[1].remove() + if remove3: + helpers[2].remove() + + o.backward() + + return z.numpy(), w.grad, x.grad, y.grad + + def double_hook(grad): + return grad * 2 + + hooks = [double_hook, double_hook, double_hook] + + for device in self.devices: + z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var( + device, hooks) + + self.assertTrue(np.array_equal(w_grad, z)) + self.assertTrue(np.array_equal(x_grad, z * 8)) + self.assertTrue(np.array_equal(y_grad, z * 8)) + + z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var( + device, hooks, remove1=True) + + self.assertTrue(np.array_equal(w_grad, z)) + self.assertTrue(np.array_equal(x_grad, z * 4)) + self.assertTrue(np.array_equal(y_grad, z * 4)) + + z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var( + device, hooks, remove2=True) + + self.assertTrue(np.array_equal(w_grad, z)) + self.assertTrue(np.array_equal(x_grad, z * 4)) + self.assertTrue(np.array_equal(y_grad, z * 4)) + + z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var( + device, hooks, remove3=True) + + self.assertTrue(np.array_equal(w_grad, z)) + 
self.assertTrue(np.array_equal(x_grad, z * 4)) + self.assertTrue(np.array_equal(y_grad, z * 4)) + + z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var( + device, hooks, remove1=True, remove2=True, remove3=True) + + self.assertTrue(np.array_equal(w_grad, z)) + self.assertTrue(np.array_equal(x_grad, z)) + self.assertTrue(np.array_equal(y_grad, z)) + + def test_hook_in_double_grad(self): + def double_print_hook(grad): + grad = grad * 2 + print(grad) + return grad + + x = paddle.ones(shape=[1], dtype='float32') + x.stop_gradient = False + + # hook only works in backward + # for forward var x, the x.grad generated in + # paddle.grad will not deal with by hook + x.register_hook(double_print_hook) + + y = x * x + + # Since y = x * x, dx = 2 * x + dx = paddle.grad( + outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0] + + z = y + dx + self.assertTrue(x.grad is None) + + # If create_graph = True, the gradient of dx + # would be backpropagated. Therefore, + # z = x * x + dx = x * x + 2 * x, and + # x.gradient() = 2 * x + 2 = 4.0 + # after changed by hook: 8.0 + + z.backward() + self.assertTrue(np.array_equal(x.grad, np.array([8.]))) + + def test_remove_one_hook_multiple_times(self): + for device in self.devices: + paddle.set_device(device) + + x = paddle.to_tensor([1., 2., 3., 4.]) + x.stop_gradient = False + + h = x.register_hook(lambda grad: grad * 2) + self.assertTrue(h.remove()) + self.assertFalse(h.remove()) + + def test_register_hook_for_stop_gradient_var(self): + for device in self.devices: + paddle.set_device(device) + + x = paddle.to_tensor([1., 2., 3., 4.]) + + with self.assertRaises(RuntimeError): + x.register_hook(lambda grad: grad * 2) + + +if __name__ == '__main__': + unittest.main() -- GitLab
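A minimal usage sketch of the Tensor.register_hook API introduced by this patch, distilled from the register_hook docstring and the test_hook_for_interior_var case in test_tensor_register_hook.py above; the gradient values in the comments follow those test expectations, and nothing here goes beyond what the patch itself exercises.

import paddle

def double_hook(grad):
    # The hook receives the computed gradient of `w` and may return a new
    # Tensor that replaces it for the rest of the backward pass; returning
    # None keeps the original gradient unchanged.
    return grad * 2

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
z = paddle.to_tensor([1., 2., 3., 4.])

w = x + y
h = w.register_hook(double_hook)  # returns a TensorHookRemoveHelper

o = z.matmul(w)
o.backward()

# w.grad itself is not rewritten; the doubled gradient is what is
# propagated on to the leaf tensors x and y.
print(w.grad)  # [1. 2. 3. 4.]
print(x.grad)  # [2. 4. 6. 8.]
print(y.grad)  # [2. 4. 6. 8.]

h.remove()  # True on the first call; False (plus a warning) afterwards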