Unverified commit dbeb3ea4 authored by Chen Weihang, committed by GitHub

Refactor and simplify hook design & add Tensor.register_hook API (#31775)

* refactor and simplify hook design

* fix reducer add hook error

* add Tensor.register_hook basic impl

* refine prepare data impl

* revert prepare data change

* support register_hook for Tensor

* add hook test in model

* polish tests and doc example

* fix double grad test failed

* remove reduce hook func

* fix set empty error

* polish code by comments

* change reduce_hook to mutable_hook

* remove useless tmp_ins

* fix shape code format error

* fix shape code format error
Parent 6b744866
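For orientation before the diff, here is a minimal sketch of the user-facing API this commit adds, mirroring the docstring example and the `test_hook_for_leaf_var` case further down (the printed values follow the assertions in that test):

```python
import paddle

# a hook may return a new gradient to replace the one being recorded
def double_hook(grad):
    return grad * 2

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
z = paddle.to_tensor([1., 2., 3., 4.])

h = y.register_hook(double_hook)   # returns a TensorHookRemoveHelper

w = x + y
o = z.matmul(w)
o.backward()

print(x.grad)  # [1. 2. 3. 4.]  -> not hooked
print(y.grad)  # [2. 4. 6. 8.]  -> leaf gradient doubled by the hook
h.remove()     # the hook no longer fires in later backward passes
```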
......@@ -141,17 +141,6 @@ void BasicEngine::PrepareGradAccumulators(
<< var.get()
<< ") that don't have grad node with reference count "
<< accumulator->RefCnt();
if (var->HasLeafHooks()) {
VLOG(3) << "Grad variable wrapper (" << var->Name()
<< ") has leaf grad hooks.";
PADDLE_ENFORCE_NE(
var->HasGradNode(), true,
platform::errors::PermissionDenied(
"Only leaf Tensor's gradient can append hook to "
"Gradientaccumulator."));
accumulator->SetPostHooks(var->GetLeafHooks());
}
} else {
// Because Inplace op overwrites the grad_node of the input grad_var. So
// only the information of grad_pending_node can be used to find the
......@@ -262,6 +251,30 @@ void BasicEngine::PrepareDeps() {
}
}
static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
const NameVarMap<VariableWrapper>& bwd_ins, const std::string& op_type) {
std::shared_ptr<NameVarMap<VariableWrapper>> tmp_ins_ptr = nullptr;
for (const auto& pair : bwd_ins) {
for (size_t i = 0; i < pair.second.size(); ++i) {
auto& var = pair.second[i];
if (var->HasHook()) {
if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
}
VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type
<< "'s input `" << pair.first << "`'s var `" << var->Name()
<< "`.";
auto tmp_var = var;
for (const auto& hook_pair : var->GetHooks()) {
tmp_var = (*hook_pair.second)(tmp_var);
}
(*tmp_ins_ptr)[pair.first][i] = tmp_var;
}
}
}
return tmp_ins_ptr;
}
void BasicEngine::Execute() {
if (init_node_ == nullptr) {
return;
......@@ -292,10 +305,15 @@ void BasicEngine::Execute() {
auto& bwd_ins = cur_op.GetInsMap();
auto& bwd_outs = cur_op.GetOutsMap();
/**
* [ Why are temporary outputs needed here? ]
*
* - construct the temp output map to avoid disrupting the graph
* - replace the elements in the map with temp vars, because one
* var may correspond to several grad vars in one op
*/
NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
// 1. construct the temp output map to avoid disrupting the graph
// 2. replace the elements in the map with temp vars, because one
// var may correspond to several grad vars in one op
for (auto& pair : tmp_outs) {
if (!pair.second.IsGrad()) {
continue;
......@@ -408,10 +426,28 @@ void BasicEngine::Execute() {
}
}
/**
* [ Why are temporary inputs needed here? ]
*
* - Hook execution should not change the original input tensor.
* Users can register hooks for a Tensor's gradient; the hook is
* expected to affect only the gradient propagated backward, not
* the gradient value passed into the hook as input.
* - use `tmp_ins_ptr`: bwd_ins is copied only when a var in bwd_ins
* holds hooks
*/
auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());
{
VLOG(3) << "Start to execute grad op " << cur_op.Type();
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, cur_op.Attrs(),
cur_op.place());
}
}
for (auto& pair : inplace_output_grad_var_list_) {
......@@ -428,15 +464,14 @@ void BasicEngine::Execute() {
if (!accumulator->SumGradCompleted()) {
continue;
}
// 1. Call Hooks for **inner_var_**
// 1. Call Hooks for `inner_var_`
accumulator->CallGradientHooks();
// 2. Sum Gradient with Previous Graph
// 2. Sum Gradient `inner_var_` to `var_` of Current or Previous Graph
accumulator->AccumulateGrad();
// 3. Call backward Hooks for **var_**
if (accumulator->HasPostHooks()) {
accumulator->CallBackwardPostHooks();
}
// 3. Call backward Hooks for `var_`
accumulator->CallReduceHooks();
}
need_accu_var_list_.clear();
......
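The `CallGradientHooks` copy-on-hook logic above preserves the semantics described in the "[ Why are temporary inputs needed here? ]" note: a hook registered on an interior Tensor rewrites only the gradient that propagates to its inputs, while the Tensor's own recorded gradient stays untouched. A small sketch of the observable behavior, following `run_double_hook_for_interior_var` in the test added by this commit:

```python
import paddle

x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
z = paddle.to_tensor([1., 2., 3., 4.], stop_gradient=False)

w = x + y                 # interior Tensor
w.stop_gradient = False
w.register_hook(lambda grad: grad * 2)

o = z.matmul(w)
o.backward()

print(w.grad)  # [1. 2. 3. 4.]  -> w's own gradient is not rewritten
print(x.grad)  # [2. 4. 6. 8.]  -> gradient flowing to the inputs is doubled
print(y.grad)  # [2. 4. 6. 8.]
```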
......@@ -384,8 +384,8 @@ static platform::Place GetPlaceOfVar(
void GradientAccumulator::AccumulateGrad() {
/**
* If the gradient has been calculated by previous graph,
* it should be added to the previous graph result.
* Once the leaf gradient has been fully calculated, the inner_var_
* should be added into the var_.
*/
if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
return;
......@@ -396,7 +396,7 @@ void GradientAccumulator::AccumulateGrad() {
"this auto-grad"));
PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true,
platform::errors::InvalidArgument(
"Interior var of Leaf tensor should be initialized."));
"Interior var of Leaf tensor should be initialized."));
auto* src = inner_var_->MutableVar();
auto* dst = var_->MutableVar();
if (!var_->IsEmpty()) {
......@@ -427,10 +427,65 @@ void GradientAccumulator::AccumulateGrad() {
*(dst) = std::move(*src);
var_->SetType(inner_var_->Type());
var_->SetDataType(inner_var_->DataType());
var_->SetIsEmpty(false);
}
inner_var_.reset();
}
void GradientAccumulator::CallGradientHooks() {
PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true,
platform::errors::Unavailable(
"Only leaf gradient Tensor can deal with by gradient "
"hook in gradient accumulator."));
PADDLE_ENFORCE_EQ(
SumGradCompleted(), true,
platform::errors::PreconditionNotMet(
"Only can call gradient hooks after sum gradient completed."));
PADDLE_ENFORCE_EQ(
HasInnerVar(), true,
platform::errors::PreconditionNotMet(
"Leaf Tensor's inner var is nullptr when call gradient hook."));
PADDLE_ENFORCE_EQ(
inner_var_->Var().IsInitialized(), true,
platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
"is not initialized when "
"call gradient hook."));
if (var_->HasHook()) {
VLOG(3) << "Call " << var_->GetHooks().size()
<< " hooks of leaf gradient accumulator's inner var `"
<< var_->Name() << "`.";
auto tmp_var = inner_var_;
VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
<< var_->GetHooks().size();
for (const auto& hook_pair : var_->GetHooks()) {
tmp_var = (*hook_pair.second)(tmp_var);
}
inner_var_ = tmp_var;
}
}
void GradientAccumulator::CallReduceHooks() {
PADDLE_ENFORCE_EQ(
var_->IsLeafGrad(), true,
platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
"by reduce hook in gradient accumulator."));
PADDLE_ENFORCE_EQ(SumGradCompleted(), true,
platform::errors::PreconditionNotMet(
"Only can call reduce hooks after the gradient "
"summation is completed in current batch."));
PADDLE_ENFORCE_EQ(HasInnerVar(), false,
platform::errors::PreconditionNotMet(
"Only can call reduce hooks after the "
"gradient accumulation is completed in "
"current batch or across batchs."));
if (var_->HasMutableHook()) {
for (const auto& hook : var_->GetMutableHooks()) {
VLOG(3) << "call gradient accumulator backward hooks.";
(*hook)(var_);
}
}
}
void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
size_t trace_id, bool unchange_input) {
/**
......
......@@ -40,8 +40,8 @@ class GradientAccumulator {
}
// inner_var_ record the grad of this auto-grad.
// Only need to generate inner var for non-empty leaf-tensor.
if (var->IsLeafGrad() && !var->IsEmpty()) {
// Only need to generate inner var for leaf-tensor.
if (var->IsLeafGrad()) {
inner_var_ = std::make_shared<VariableWrapper>(var->Name());
inner_var_->SetType(var->Type());
inner_var_->SetDataType(var->DataType());
......@@ -52,9 +52,6 @@ class GradientAccumulator {
<< ") to store result of this Graph";
}
// TODO(zhouwei): fix Tensor.clear_gradient() bug, remove this hard flag
var->SetIsEmpty(false);
// var_ is the final grad, processed by hooks and grad accumulation
var_ = var;
}
......@@ -93,42 +90,38 @@ class GradientAccumulator {
inline bool HasInnerVar() const { return inner_var_ != nullptr; }
/* Hook related methods */
inline bool HasPostHooks() const { return !post_hooks_.expired(); }
void SetPostHooks(const std::shared_ptr<LeafVarHookPipeline>& hooks) {
PADDLE_ENFORCE_NOT_NULL(
hooks, platform::errors::InvalidArgument(
"The hook set to GradientAccumulator is nullptr."));
auto shared_hooks = post_hooks_.lock();
if (shared_hooks != hooks) {
PADDLE_ENFORCE_EQ(
shared_hooks, nullptr,
platform::errors::PermissionDenied(
"Cannot set post hooks twice to GradientAccumulator."));
post_hooks_ = hooks;
}
}
// void CallHooks(){}
// ** inner_var_ **
// function that Sum Gradient with Previous Graph
void AccumulateGrad();
// call backward post hooks, such as reduce hook
void CallBackwardPostHooks() {
PADDLE_ENFORCE_NE(
post_hooks_.expired(), true,
platform::errors::NotFound(
"The post hooks of GradientAccumulator for Tensor `%s` expired.",
var_->Name()));
auto shared_hooks = post_hooks_.lock();
for (const auto& hook : shared_hooks->backward_hooks()) {
VLOG(3) << "call gradient accumulator backward hooks.";
(*hook)(var_);
}
}
/** [ Hook related methods ]
*
* [ Why are two types of VariableWrapperHook needed? ]
*
* There are two types of gradient accumulation:
* 1. Gradient accumulation within the same batch
* 2. Gradient accumulation across batches
* The order of execution between hooks and gradient accumulation:
* [ Gradient accumulation within the same batch ]
* |
* [ leaf GradVarBase hooks ]
* |
* [ Gradient accumulation across batches ]
* |
* [ Gradient reduce / allreduce hooks ]
* Because we currently intend to handle both kinds of gradient
* accumulation in one GradientAccumulator, we must distinguish between
* the two types of hooks.
* Note that InplaceVariableWrapperHook cannot be registered by users
* directly; it is currently only used to support the reduce strategy of
* parallel multi-card training.
*/
void CallGradientHooks();
void CallReduceHooks();
protected:
VariableWrapper* var_;
......@@ -137,7 +130,6 @@ class GradientAccumulator {
std::shared_ptr<VariableWrapper> inner_var_;
size_t ref_cnt_{0};
size_t cur_cnt_{0};
std::weak_ptr<LeafVarHookPipeline> post_hooks_;
};
class EagerGradientAccumulator : public GradientAccumulator {
......
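The comment block above fixes the ordering: gradients are first summed within a batch, then the leaf `register_hook` hooks run on that per-batch sum, and only afterwards is the result accumulated across batches (with reduce hooks last). A hedged Python sketch of what that ordering implies for a leaf Tensor whose gradient is accumulated over two backward passes (the numbers follow from the ordering described above, not from a test in this diff):

```python
import paddle

def double_hook(grad):
    # runs on the gradient summed within the current backward pass,
    # before it is accumulated into x.grad across passes
    return grad * 2

x = paddle.to_tensor([1., 2., 3.], stop_gradient=False)
x.register_hook(double_hook)

for _ in range(2):        # two passes without clear_grad()
    y = (x * x).sum()
    y.backward()

# each pass contributes hook(2 * x) = 4 * x, so after two passes
# the accumulated gradient is expected to be 8 * x
print(x.grad)             # [ 8. 16. 24.]
```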
......@@ -18,100 +18,67 @@
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace imperative {
class VariableWrapper;
/** [ Basic hook classes ]
*
* @brief OpBasePreHook is executed before the grad OpBase is executed,
/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ]
*
* @brief This hook functor is executed before the grad OpBase is executed,
* taking the input of the current grad OpBase as input, and
* executing Python hooks (user-defined) or C++ hooks (developer-defined)
* to perform custom operations on the interior VarBase
* gradient.
*
* @note OpBasePreHook will not change the input gradient VarBase.
* @note This hook functor will not change the input gradient VarBase.
*
* @note [ Why must this be an OpBase `PreHook` rather than a `PostHook`? ]
*
* If set OpBase post hook, when the op executed end, the op's output
* gradient may not be the final state, because it may need other op's
* gradient output to accumulated to it. But before op can be executed,
* the gradient output must have been accumulated to final value.
* 1. If we registered it as an OpBase post hook, the op's output
* gradient might not yet be in its final state when the op finishes
* executing, because other ops' gradient outputs may still need to be
* accumulated into it. Before an op is executed, however, its input
* gradients must already hold their final accumulated values.
* 2. We don't want the hook to change its input Tensor value, so we
* can't call all hooks in the GradAccumulator.
*
* @note [ Why can this only be used for interior VarBase? ]
*
* Because the leaf VarBase's GradVarBase has no GradOpNode, a leaf
* GradVarBase has no next OpBase to execute, so if we need to deal with
* the leaf GradVarBase, we cannot use OpBasePreHook. For this case, we
* deal with it by GradAccumulatorPostHook.
* the leaf GradVarBase, we cannot use this hook functor. That case is
* handled by the other, inplace hook method.
*/
class OpBasePreHook {
class VariableWrapperHook {
public:
virtual ~OpBasePreHook() = default;
virtual VariableWrapperList operator()(
const VariableWrapperList& grad_inputs) = 0;
virtual ~VariableWrapperHook() = default;
virtual std::shared_ptr<VariableWrapper> operator()(
const std::shared_ptr<VariableWrapper>& var) = 0;
};
/**
* @brief GradAccumulatorPostHook is the Hook that operates on the current
/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ]
*
* @brief This hook functor is the Hook that operates on the current
* gradient after the GradientAccumulator has accumulated the gradient.
* A leaf GradVarBase has no next OpBase, so if we want to register a hook
* for it, we need to wait until the leaf GradVarBase's accumulation
* is completed; therefore we add a post hook to the GradientAccumulator.
*
* @note GradAccumulatorPostHook will change the grad VarBase value.
* @note This hook functor will change the grad VarBase value.
*
* @note Only allow leaf VarBase hold GradientAccumulatorPostHook.
* @note Only a leaf VarBase is allowed to call this hook functor.
*/
class GradAccumulatorPostHook {
class InplaceVariableWrapperHook {
public:
virtual ~GradAccumulatorPostHook() = default;
virtual ~InplaceVariableWrapperHook() = default;
virtual void operator()(VariableWrapper* var) = 0;
};
/** [ Hook for cpp functions ]
*
* Here we design three C++ hooks;
* 1. CppOpBasePreHook (Implement later):
* - used for developer-defined C++ interior VarBase hooks
* 2. CppGradAccumulatorPostHook (Implement later):
* - used for developer-defined C++ leaf VarBase hooks
* 3. LambdaGradAccumulatorPostHook:
* - used for VarBase reduce in parallel training
*
* @note [Why need two types of GradAccumulatorPostHook? ]
*
* There are two types of gradient accumulation:
* 1. Gradient accumulation in same batch
* 2. Gradient accumulation across batchs
* The order of execution between Hooks and gradient accumulation:
*
* [ Gradient accumulation in same batch]
* |
* [ leaf GradVarBase hooks ]
* |
* [ Gradient accumulation across batchs ]
* |
* [ Gradient reduce / allreduce]
*
* Because we currently intend to accumulate these two gradient
* accumulation in one GradientAccumulator, We must distinguish between
* two types of hooks.
*
* And the LambdaGradAccumulatorPostHook does not allow users to register
* directly, and is currently only used to support the reduce strategy of
* parallel multi-card training.
*/
class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook {
class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook {
public:
explicit LambdaGradAccumulatorPostHook(
std::function<void(VariableWrapper*)> fn)
explicit LambdaInplaceVariableWrapperHook(
std::function<void(VariableWrapper*)>&& fn)
: fn_(std::move(fn)) {}
void operator()(VariableWrapper* var) override { fn_(var); }
......@@ -120,114 +87,5 @@ class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook {
std::function<void(VariableWrapper*)> fn_;
};
/* Hooks for python function: in pybind/imperative.cc */
/** Add Python Hooks later:
* - PyOpBasePreHook (Implement later): used for user-defined interior python
* VarBase hooks
* - PyGradAccumulatorPostHook (Implement later): used for user-defined leaf
* python VarBase hooks
*/
/** [ Hook Pipeline classes ]
*
* @note [Why need hook pipeline classes?]
*
* There are 2 purposes for adding Hook pipeline here:
*
* 1. Make the code implementation cleaner.
*
* If there are no Hook pipeline, we need to add 3 hook vector into
* VariableWrapper, 1 hook vector into OpBase, 2 hook vector into
* GradientAccumulator, like:
*
* - VariableWrapper:
* std::vector<std::shared_ptr<OpBasePreHook>>
* interior_var_hooks_;
* std::vector<std::shared_ptr<GradAccumulatorPostHook>>
* leaf_var_hooks_;
* std::vector<std::shared_ptr<GradAccumulatorPostHook>>
* backward_hooks_;
*
* - OpBase:
* std::vector<std::weak_ptr<OpBasePreHook>>
* interior_var_hooks_;
*
* - GradientAccumulator:
* std::vector<std::weak_ptr<GradAccumulatorPostHook>>
* leaf_var_hooks_;
* std::vector<std::weak_ptr<GradAccumulatorPostHook>>
* backward_hooks_;
*
* This seems more complicated, and std::vector<std::weak_ptr<...>>
* is not easy to destruct.
*
* 2. Make the code easier to understand.
*
* From these two packages, we can clearly understand that we
* have two types of Hooks, respectively for the interior
* gradient var and leaf gradient var inside the backward
* calculation graph.
*/
class InteriorVarHookPipeline {
public:
InteriorVarHookPipeline() = default;
void add_hook(std::unique_ptr<OpBasePreHook>&& hook) {
hooks_.emplace_back(std::move(hook));
}
const std::vector<std::unique_ptr<OpBasePreHook>>& hooks() const {
return hooks_;
}
std::vector<std::unique_ptr<OpBasePreHook>>& hooks() { return hooks_; }
private:
std::vector<std::unique_ptr<OpBasePreHook>> hooks_;
DISABLE_COPY_AND_ASSIGN(InteriorVarHookPipeline);
};
class LeafVarHookPipeline {
public:
LeafVarHookPipeline() = default;
void add_hook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
hooks_.emplace_back(std::move(hook));
}
const std::vector<std::unique_ptr<GradAccumulatorPostHook>>& hooks() const {
return hooks_;
}
std::vector<std::unique_ptr<GradAccumulatorPostHook>>& hooks() {
return hooks_;
}
void add_backward_hook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
backward_hooks_.emplace_back(std::move(hook));
}
const std::vector<std::unique_ptr<GradAccumulatorPostHook>>& backward_hooks()
const {
return backward_hooks_;
}
std::vector<std::unique_ptr<GradAccumulatorPostHook>>& backward_hooks() {
return backward_hooks_;
}
private:
std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_;
// NOTE: the `backward` here means the `whole backward process`,
// the `backward_hooks_` need to be executed after the `whole backward
// process`.
std::vector<std::unique_ptr<GradAccumulatorPostHook>> backward_hooks_;
DISABLE_COPY_AND_ASSIGN(LeafVarHookPipeline);
};
} // namespace imperative
} // namespace paddle
......@@ -30,6 +30,7 @@
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/flags.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/saved_variable_wrapper_list.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
......@@ -220,6 +221,26 @@ class VarBase {
void BumpInplaceVersion();
/* Hook related method: now only used for GradVarBase */
bool HasHook() const { return var_->HasHook(); }
int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
return var_->AddHook(
std::forward<std::shared_ptr<VariableWrapperHook>>(hook));
}
bool RemoveHook(const int64_t& hook_id) { return var_->RemoveHook(hook_id); }
const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
const {
return var_->GetHooks();
}
void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
var_->AddMutableHook(
std::forward<std::shared_ptr<InplaceVariableWrapperHook>>(hook));
}
private:
/**
* NOTE(zengjinle): never remove the const qualifier of `var_` if you are
......
......@@ -177,8 +177,6 @@ class OpBase {
std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_;
size_t id_{-1UL};
std::weak_ptr<InteriorVarHookPipeline> pre_hooks_;
};
class GradOpNode {
......
......@@ -369,6 +369,10 @@ class GradientAccumulationInfo {
*is_finished = (cur_ref_cnt_ == total_ref_cnt_);
accumulator_->SumGrad(grad_var_partial, trace_id, unchange_input);
if (*is_finished && accumulator_->HasInnerVar()) {
accumulator_->AccumulateGrad();
}
if (create_graph_) {
VLOG(10) << "Store partial grad grad for double grad "
<< mapped_grad_var_->Name();
......
......@@ -310,11 +310,9 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
for (size_t global_var_index = 0; global_var_index < vars_.size();
++global_var_index) {
auto var = vars_[global_var_index];
var->SharedVar()->AddGradVarLeafBackwardHook(
std::unique_ptr<LambdaGradAccumulatorPostHook>(
new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) {
this->AddDistHook(global_var_index);
})));
var->GradVarBase()->AddMutableHook(
std::make_shared<LambdaInplaceVariableWrapperHook>([=](
VariableWrapper *grad) { this->AddDistHook(global_var_index); }));
var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
}
}
......
......@@ -74,16 +74,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
mul_attr_map["use_mkldnn"] = false;
// add GradAccumulatorPostHook
auto x_var_wrapper = x->SharedVar();
x_var_wrapper->AddGradVarLeafBackwardHook(
std::unique_ptr<LambdaGradAccumulatorPostHook>(
new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) {
x->GradVarBase()->AddMutableHook(
std::make_shared<LambdaInplaceVariableWrapperHook>(
[=](VariableWrapper* grad) {
auto* grad_tensor =
grad->MutableVar()->GetMutable<framework::LoDTensor>();
for (int i = 0; i < grad_tensor->numel(); ++i) {
grad_tensor->mutable_data<float>(place)[i] *= 2.0;
}
})));
}));
// 2. forward
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
......@@ -151,17 +150,16 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
memory::Copy(place, mutable_z, place, src_data.data(),
sizeof(float) * src_data.size());
// add GradAccumulatorPostHook
auto x_var_wrapper = x->SharedVar();
x_var_wrapper->AddGradVarLeafBackwardHook(
std::unique_ptr<LambdaGradAccumulatorPostHook>(
new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) {
// add ReduceBackwardHook
x->GradVarBase()->AddMutableHook(
std::make_shared<LambdaInplaceVariableWrapperHook>(
[=](VariableWrapper* grad) {
auto* grad_tensor =
grad->MutableVar()->GetMutable<framework::LoDTensor>();
for (int i = 0; i < grad_tensor->numel(); ++i) {
grad_tensor->mutable_data<float>(place)[i] *= 2.0;
}
})));
}));
// 2. forward
var_pair x_pair = var_pair("X", vb_vector(1, x));
......
......@@ -27,8 +27,8 @@
namespace paddle {
namespace imperative {
class InteriorVarHookPipeline;
class LeafVarHookPipeline;
class VariableWrapperHook;
class InplaceVariableWrapperHook;
class VarBase;
class GradOpNode;
......@@ -193,42 +193,6 @@ class VariableWrapper {
}
}
/* Hook related method: only can be call by GradVarBase */
bool HasInteriorHooks() const { return interior_hooks_ != nullptr; }
bool HasLeafHooks() const { return leaf_hooks_ != nullptr; }
void AddGradVarInteriorHook(std::unique_ptr<OpBasePreHook>&& hook) {
auto interior_hooks = GetGradVarInteriorHooksSafely();
interior_hooks->add_hook(std::move(hook));
}
void AddGradVarLeafHook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
auto leaf_hooks = GetGradVarLeafHooksSafely();
leaf_hooks->add_hook(std::move(hook));
}
void AddGradVarLeafBackwardHook(
std::unique_ptr<GradAccumulatorPostHook>&& hook) {
auto leaf_hooks = GetGradVarLeafHooksSafely();
leaf_hooks->add_backward_hook(std::move(hook));
}
const std::shared_ptr<InteriorVarHookPipeline>& GetInteriorHooks() const {
return interior_hooks_;
}
std::shared_ptr<InteriorVarHookPipeline>& GetInteriorHooks() {
return interior_hooks_;
}
const std::shared_ptr<LeafVarHookPipeline>& GetLeafHooks() const {
return leaf_hooks_;
}
std::shared_ptr<LeafVarHookPipeline>& GetLeafHooks() { return leaf_hooks_; }
uint32_t InplaceVersionSnapshot() const { return inplace_version_snapshot_; }
void ResetInplaceVersion() {
......@@ -255,6 +219,38 @@ class VariableWrapper {
return;
}
/* Hook related methods */
bool HasHook() const { return !hooks_.empty(); }
bool HasMutableHook() const { return !mutable_hooks_.empty(); }
int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
hooks_.emplace(next_hook_id_, std::move(hook));
return next_hook_id_++;
}
bool RemoveHook(const int64_t& hook_id) {
auto remove_cnt = hooks_.erase(hook_id);
if (remove_cnt == 0) {
return false;
}
return true;
}
const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
const {
return hooks_;
}
void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
mutable_hooks_.emplace_back(std::move(hook));
}
const std::vector<std::shared_ptr<InplaceVariableWrapperHook>>&
GetMutableHooks() const {
return mutable_hooks_;
}
private:
void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
auto shared_var = grad_var_.lock();
......@@ -289,41 +285,6 @@ class VariableWrapper {
}
}
/* Hook related private methods */
std::shared_ptr<VariableWrapper> GetGradVarSafely() const {
auto shared_grad_var = grad_var_.lock();
PADDLE_ENFORCE_NOT_NULL(
shared_grad_var,
platform::errors::PermissionDenied(
"Cannot add gradient hook on Tensor without gradient."));
return shared_grad_var;
}
std::shared_ptr<InteriorVarHookPipeline>& GetGradVarInteriorHooksSafely() {
auto shared_grad_var = GetGradVarSafely();
PADDLE_ENFORCE_EQ(HasGradNode(), true,
platform::errors::PermissionDenied(
"Only interior Tensor in backward can register "
"interior gradient hook."));
if (shared_grad_var->interior_hooks_ == nullptr) {
shared_grad_var->interior_hooks_ =
std::make_shared<InteriorVarHookPipeline>();
}
return shared_grad_var->interior_hooks_;
}
std::shared_ptr<LeafVarHookPipeline>& GetGradVarLeafHooksSafely() {
auto shared_grad_var = GetGradVarSafely();
PADDLE_ENFORCE_EQ(
HasGradNode(), false,
platform::errors::PermissionDenied(
"Only leaf Tensor in backward can register leaf gradient hook."));
if (shared_grad_var->leaf_hooks_ == nullptr) {
shared_grad_var->leaf_hooks_ = std::make_shared<LeafVarHookPipeline>();
}
return shared_grad_var->leaf_hooks_;
}
private:
framework::Variable var_;
std::string name_;
......@@ -358,11 +319,14 @@ class VariableWrapper {
// isn't need
bool is_empty_{false};
// NOTE: only grad var can hold hooks now
// only interior var can hold interior hooks
std::shared_ptr<InteriorVarHookPipeline> interior_hooks_;
// only leaf var can hold leaf hooks
std::shared_ptr<LeafVarHookPipeline> leaf_hooks_;
// NOTE(chenweihang): only grad var can hold hooks now
int64_t next_hook_id_{0};
// Hooks registered on the grad var; they support adding and removing.
// The key is an auto-incremented int64_t id.
std::map<int64_t, std::shared_ptr<VariableWrapperHook>> hooks_;
// Hooks executed after the entire backward process finishes; currently
// only used for gradient reduce in distributed training
std::vector<std::shared_ptr<InplaceVariableWrapperHook>> mutable_hooks_;
};
} // namespace imperative
......
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/bkcl_context.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/partial_grad_engine.h"
......@@ -63,6 +64,65 @@ class Layer : public imperative::Layer {
}
};
template <typename T>
static T PyObjectCast(PyObject *obj) {
try {
return py::cast<T>(py::handle(obj));
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Python object is not type of %s", typeid(T).name()));
}
}
class PyVariableWrapperHook : public imperative::VariableWrapperHook {
public:
explicit PyVariableWrapperHook(PyObject *func) : py_func_(func) {
Py_INCREF(py_func_);
}
~PyVariableWrapperHook() {
py::gil_scoped_acquire gil;
Py_DECREF(py_func_);
}
std::shared_ptr<imperative::VariableWrapper> operator()(
const std::shared_ptr<imperative::VariableWrapper> &var) override {
py::gil_scoped_acquire gil;
VLOG(3) << "Call PyVariableWrapperHook for var " << var->Name();
// 1. unpack temp VarBase from VariableWrapper
std::shared_ptr<imperative::VarBase> tmp_varbase =
std::make_shared<imperative::VarBase>(var);
// 2. call hook and return
PyObject *res = nullptr;
try {
res = PyObject_CallFunctionObjArgs(py_func_, py::cast(tmp_varbase).ptr(),
nullptr);
} catch (platform::EnforceNotMet &e) {
throw std::move(e);
} catch (std::exception &e) {
PADDLE_THROW(platform::errors::Unavailable(
"Hook function of Tensor raises an exception: %s.", e.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Hook function of Tensor raises an unknown exception."));
}
PADDLE_ENFORCE_NOT_NULL(res,
platform::errors::Unavailable(
"Hook function of Tensor return a nullptr."));
if (res == Py_None) {
return var;
}
return PyObjectCast<std::shared_ptr<imperative::VarBase>>(res)->SharedVar();
}
private:
PyObject *py_func_;
};
static const platform::Place PyObjectToPlace(const py::object &place_obj) {
if (py::isinstance<platform::CPUPlace>(place_obj)) {
return place_obj.cast<platform::CPUPlace>();
......@@ -213,16 +273,6 @@ static std::string GetTypeName(const imperative::VarBase &var) {
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
template <typename T>
static T PyObjectCast(PyObject *obj) {
try {
return py::cast<T>(py::handle(obj));
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Python object is not type of %s", typeid(T).name()));
}
}
// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
static std::vector<std::shared_ptr<imperative::VarBase>>
......@@ -1023,6 +1073,23 @@ void BindImperative(py::module *m_ptr) {
}
},
py::call_guard<py::gil_scoped_release>())
.def("_register_grad_hook",
[](imperative::VarBase &self, const py::handle &hook) {
PADDLE_ENFORCE_EQ(
self.HasGradVar(), true,
platform::errors::InvalidArgument(
"Cannot register hook on a tensor without gradient."));
return self.GradVarBase()->AddHook(
std::make_shared<PyVariableWrapperHook>(hook.ptr()));
})
.def("_remove_grad_hook",
[](imperative::VarBase &self, int64_t hook_id) {
PADDLE_ENFORCE_EQ(
self.HasGradVar(), true,
platform::errors::InvalidArgument(
"Cannot remove hook on a tensor without gradient."));
return self.GradVarBase()->RemoveHook(hook_id);
})
.def("cpu",
[](const std::shared_ptr<imperative::VarBase> &self) {
if (platform::is_cpu_place(self->Place())) {
......@@ -1231,22 +1298,28 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable)
.def_property_readonly(
"shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return framework::vectorize<int>(
self.Var().Get<framework::LoDTensor>().dims());
} else if (self.Var().IsType<framework::SelectedRows>()) {
return framework::vectorize<int>(
self.Var().Get<framework::SelectedRows>().value().dims());
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::LoDTensor>()
.dims());
} else if (self.Var()
.IsType<
framework::SelectedRows>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::SelectedRows>()
.value()
.dims());
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
R"DOC(
Whether a Tensor is leaf Tensor.
......
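The `_register_grad_hook` / `_remove_grad_hook` bindings above are the low-level surface that the Python-side `register_hook` and `TensorHookRemoveHelper` (in the next file) build on. A sketch of that mapping; note these underscore methods are internal and are normally not called directly:

```python
import paddle

x = paddle.to_tensor([1., 2., 3.], stop_gradient=False)

# register_hook() wraps these internal bindings:
hook_id = x._register_grad_hook(lambda grad: grad * 2)  # int64 hook id
removed = x._remove_grad_hook(hook_id)                  # True if the id existed
print(removed)  # True
```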
......@@ -14,6 +14,8 @@
import inspect
import numpy as np
import warnings
import weakref
import paddle
from .. import framework
......@@ -26,6 +28,34 @@ from .parallel import scale_loss
from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
class TensorHookRemoveHelper(object):
"""
A helper class for removing the hook registered on a Tensor's gradient.
"""
def __init__(self, tensor, hook_id):
self._tensor_ref = weakref.ref(tensor)
self._hook_id = hook_id
def remove(self):
"""
Remove the hook registered on the referenced Tensor.
Returns:
bool: True if the hook is removed successfully, otherwise False.
"""
tensor = self._tensor_ref()
if tensor is not None:
res = tensor._remove_grad_hook(self._hook_id)
if res is True:
return True
else:
warnings.warn(
"The backward hook (ID: %d) of Tensor `%s` you want to remove does not exist or has been removed."
% (self._hook_id, tensor.name), RuntimeWarning)
return False
def monkey_patch_varbase():
@switch_to_static_graph
def _to_static_var(self, to_parameter=False, **kwargs):
......@@ -211,6 +241,73 @@ def monkey_patch_varbase():
else:
return np.array(new_ivar.value().get_tensor())
@framework.dygraph_only
def register_hook(self, hook):
"""
Registers a backward hook for the current Tensor.
The hook will be called every time the gradient of the current Tensor is computed.
The hook should not modify the input gradient Tensor, but it can optionally return
a new gradient Tensor which will be used in place of the current Tensor's gradient.
The hook should have the following signature:
hook(grad) -> Tensor or None
Args:
hook(function): A backward hook to be registered for Tensor.grad
Returns:
TensorHookRemoveHelper: A helper object that can be used to remove the registered hook by calling its `remove()` method.
Examples:
.. code-block:: python
import paddle
# hook function return None
def print_hook_fn(grad):
print(grad)
# hook function return Tensor
def double_hook_fn(grad):
grad = grad * 2
return grad
x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
z = paddle.to_tensor([1., 2., 3., 4.])
# one Tensor can register multiple hooks
h = x.register_hook(print_hook_fn)
x.register_hook(double_hook_fn)
w = x + y
# register hook by lambda function
w.register_hook(lambda grad: grad * 2)
o = z.matmul(w)
o.backward()
# print_hook_fn print content in backward
# Tensor(shape=[4], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [2., 4., 6., 8.])
print("w.grad:", w.grad) # w.grad: [1. 2. 3. 4.]
print("x.grad:", x.grad) # x.grad: [ 4. 8. 12. 16.]
print("y.grad:", y.grad) # y.grad: [2. 4. 6. 8.]
# remove hook
h.remove()
"""
if self.stop_gradient is True:
raise RuntimeError(
"Cannot register hook on a tensor that stop gradient.")
hook_id = self._register_grad_hook(hook)
helper = TensorHookRemoveHelper(self, hook_id)
return helper
@property
def grad(self):
"""
......@@ -316,7 +413,8 @@ def monkey_patch_varbase():
("_to_static_var", _to_static_var), ("set_value", set_value),
("block", block), ("backward", backward), ("clear_grad", clear_grad),
("inplace_version", inplace_version), ("grad", grad),
("gradient", gradient), ("__str__", __str__), ("__repr__", __str__),
("gradient", gradient), ("register_hook", register_hook),
("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor")):
setattr(core.VarBase, method_name, method)
......
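`register_hook` returns a `TensorHookRemoveHelper`; as `test_remove_one_hook_multiple_times` below checks, the first `remove()` call succeeds, and later calls warn and return False. A short sketch:

```python
import paddle

x = paddle.to_tensor([1., 2., 3., 4.])
x.stop_gradient = False

h = x.register_hook(lambda grad: grad * 2)  # TensorHookRemoveHelper

print(h.remove())  # True: the hook existed and was removed
print(h.remove())  # False, with a RuntimeWarning: already removed
```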
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.nn as nn
class SimpleNet(nn.Layer):
def __init__(self, in_size, out_size):
super(SimpleNet, self).__init__()
self.linear1 = nn.Linear(in_size, in_size)
self.linear2 = nn.Linear(in_size, out_size)
def forward(self, x, hook=None, register=False, remove=False):
ret1 = self.linear1(x)
if hook is not None:
if register:
h = ret1.register_hook(hook)
if remove:
h.remove()
ret2 = self.linear2(ret1)
out = paddle.mean(ret2, axis=-1)
return ret1, out
class TestTensorRegisterHook(unittest.TestCase):
def setUp(self):
self.seed = 2021
self.in_size = 10
self.out_size = 10
self.batch_size = 4
self.devices = ["cpu"]
if paddle.is_compiled_with_cuda():
self.devices.append("gpu")
def test_hook_for_interior_var(self):
def run_double_hook_for_interior_var(double_hook, removed=False):
for device in self.devices:
paddle.set_device(device)
x = paddle.to_tensor([0., 1., 2., 3.])
y = paddle.to_tensor([4., 5., 6., 7.])
x.stop_gradient = False
y.stop_gradient = False
w = x + y
w.stop_gradient = False
helper = w.register_hook(double_hook)
z = paddle.to_tensor([1., 2., 3., 4.])
z.stop_gradient = False
o = z.matmul(w)
# remove hook before backward
if removed:
helper.remove()
o.backward()
# z.grad is not affected
self.assertTrue(np.array_equal(z.grad, w.numpy()))
# w.grad is not changed by hook
self.assertTrue(np.array_equal(w.grad, z.numpy()))
# x.grad and y.grad are changed if the hook runs
self.assertTrue(
np.array_equal(x.grad,
z.numpy() * 2 if not removed else z.numpy()))
self.assertTrue(
np.array_equal(y.grad,
z.numpy() * 2 if not removed else z.numpy()))
def run_print_hook_for_interior_var(print_hook, removed=False):
for device in self.devices:
paddle.set_device(device)
x = paddle.to_tensor([0., 1., 2., 3.])
y = paddle.to_tensor([4., 5., 6., 7.])
x.stop_gradient = False
y.stop_gradient = False
w = x + y
w.stop_gradient = False
helper = w.register_hook(print_hook)
z = paddle.to_tensor([1., 2., 3., 4.])
z.stop_gradient = False
o = z.matmul(w)
# remove hook before backward
if removed:
helper.remove()
o.backward()
# all grads are not affected
self.assertTrue(np.array_equal(z.grad, w.numpy()))
self.assertTrue(np.array_equal(w.grad, z.numpy()))
self.assertTrue(np.array_equal(x.grad, z.numpy()))
self.assertTrue(np.array_equal(y.grad, z.numpy()))
def double_hook(grad):
grad = grad * 2
print(grad)
return grad
def print_hook(grad):
print(grad)
# register hook
run_double_hook_for_interior_var(double_hook)
# register hook and removed
run_double_hook_for_interior_var(double_hook, removed=True)
# register hook
run_double_hook_for_interior_var(lambda grad: grad * 2)
# register hook and removed
run_double_hook_for_interior_var(lambda grad: grad * 2, removed=True)
# register hook
run_print_hook_for_interior_var(print_hook)
# register hook and removed
run_print_hook_for_interior_var(print_hook, removed=True)
def test_hook_for_leaf_var(self):
def run_double_hook_for_leaf_var(double_hook, removed=False):
for device in self.devices:
paddle.set_device(device)
x = paddle.to_tensor([0., 1., 2., 3.])
y = paddle.to_tensor([4., 5., 6., 7.])
x.stop_gradient = False
y.stop_gradient = False
helper = y.register_hook(double_hook)
w = x + y
w.stop_gradient = False
z = paddle.to_tensor([1., 2., 3., 4.])
z.stop_gradient = False
o = z.matmul(w)
# remove hook before backward
if removed:
helper.remove()
o.backward()
# z.grad, w.grad and x.grad are not affected
self.assertTrue(np.array_equal(z.grad, w.numpy()))
self.assertTrue(np.array_equal(w.grad, z.numpy()))
self.assertTrue(np.array_equal(x.grad, z.numpy()))
# y.grad is changed if the hook runs
self.assertTrue(
np.array_equal(y.grad,
z.numpy() * 2 if not removed else z.numpy()))
# register hook
run_double_hook_for_leaf_var(lambda grad: grad * 2)
# register hook and removed
run_double_hook_for_leaf_var(lambda grad: grad * 2, removed=True)
def test_hook_for_accumulated_grad(self):
def run_double_hook_for_accumulated_grad(double_hook, removed=False):
for device in self.devices:
paddle.set_device(device)
a = paddle.to_tensor([0., 1., 1., 2.])
b = paddle.to_tensor([0., 0., 1., 2.])
a.stop_gradient = False
b.stop_gradient = False
helper1 = a.register_hook(double_hook)
x = a + b
x.stop_gradient = False
helper2 = x.register_hook(double_hook)
y = paddle.to_tensor([4., 5., 6., 7.])
z = paddle.to_tensor([1., 2., 3., 4.])
y.stop_gradient = False
z.stop_gradient = False
o1 = x + y
o2 = x + z
o1.stop_gradient = False
o2.stop_gradient = False
o = o1.matmul(o2)
# remove hook before backward
if removed:
helper1.remove()
helper2.remove()
o.backward()
base_grad = np.array([5., 9., 13., 19.])
# x.grad is not changed
self.assertTrue(np.array_equal(x.grad, base_grad))
# b.grad is changed by x.hook
self.assertTrue(
np.array_equal(b.grad, base_grad * 2
if not removed else base_grad))
# a.grad is changed by x.hook and a.hook
self.assertTrue(
np.array_equal(a.grad, base_grad * 4
if not removed else base_grad))
# register hook
run_double_hook_for_accumulated_grad(lambda grad: grad * 2)
# register hook and removed
run_double_hook_for_accumulated_grad(
lambda grad: grad * 2, removed=True)
def test_hook_in_model(self):
def run_double_hook_in_model(data,
label,
hook=None,
register=False,
remove=False):
for device in self.devices:
paddle.seed(self.seed)
paddle.set_device(device)
net = SimpleNet(self.in_size, self.out_size)
loss_fn = nn.MSELoss()
data = paddle.to_tensor(data)
label = paddle.to_tensor(label)
ret1, out = net(data, hook, register, remove)
loss = loss_fn(out, label)
loss.backward()
return ret1.grad, net.linear1.weight.grad, net.linear1.bias.grad
data = np.random.uniform(
size=[self.batch_size, self.in_size]).astype('float32')
label = np.random.uniform(size=[self.batch_size, 1]).astype('float32')
# get original value
ret1_grad, linear1_w_grad, linear1_b_grad = run_double_hook_in_model(
data, label)
# get value changed by hook
ret1_grad_hook, linear1_w_grad_hook, linear1_b_grad_hook = run_double_hook_in_model(
data, label, lambda grad: grad * 2, True)
# get value after removing hook
ret1_grad_rm, linear1_w_grad_rm, linear1_b_grad_rm = run_double_hook_in_model(
data, label, lambda grad: grad * 2, True, True)
# compare original value and with hook
self.assertTrue(np.array_equal(ret1_grad, ret1_grad_hook))
self.assertTrue(np.array_equal(linear1_w_grad * 2, linear1_w_grad_hook))
self.assertTrue(np.array_equal(linear1_b_grad * 2, linear1_b_grad_hook))
# compare original value and remove hook
self.assertTrue(np.array_equal(ret1_grad, ret1_grad_rm))
self.assertTrue(np.array_equal(linear1_w_grad, linear1_w_grad_rm))
self.assertTrue(np.array_equal(linear1_b_grad, linear1_b_grad_rm))
def test_multiple_hooks_for_interior_var(self):
def run_multiple_hooks_for_interior_var(device,
hooks,
remove1=False,
remove2=False,
remove3=False):
paddle.set_device(device)
x = paddle.to_tensor([0., 1., 2., 3.])
y = paddle.to_tensor([4., 5., 6., 7.])
x.stop_gradient = False
y.stop_gradient = False
w = x + y
w.stop_gradient = False
helpers = []
for hook in hooks:
helper = w.register_hook(hook)
helpers.append(helper)
z = paddle.to_tensor([1., 2., 3., 4.])
z.stop_gradient = False
o = z.matmul(w)
if remove1:
helpers[0].remove()
if remove2:
helpers[1].remove()
if remove3:
helpers[2].remove()
o.backward()
return z.numpy(), w.grad, x.grad, y.grad
def double_hook(grad):
return grad * 2
hooks = [double_hook, double_hook, double_hook]
for device in self.devices:
z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var(
device, hooks)
self.assertTrue(np.array_equal(w_grad, z))
self.assertTrue(np.array_equal(x_grad, z * 8))
self.assertTrue(np.array_equal(y_grad, z * 8))
z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var(
device, hooks, remove1=True)
self.assertTrue(np.array_equal(w_grad, z))
self.assertTrue(np.array_equal(x_grad, z * 4))
self.assertTrue(np.array_equal(y_grad, z * 4))
z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var(
device, hooks, remove2=True)
self.assertTrue(np.array_equal(w_grad, z))
self.assertTrue(np.array_equal(x_grad, z * 4))
self.assertTrue(np.array_equal(y_grad, z * 4))
z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var(
device, hooks, remove3=True)
self.assertTrue(np.array_equal(w_grad, z))
self.assertTrue(np.array_equal(x_grad, z * 4))
self.assertTrue(np.array_equal(y_grad, z * 4))
z, w_grad, x_grad, y_grad = run_multiple_hooks_for_interior_var(
device, hooks, remove1=True, remove2=True, remove3=True)
self.assertTrue(np.array_equal(w_grad, z))
self.assertTrue(np.array_equal(x_grad, z))
self.assertTrue(np.array_equal(y_grad, z))
def test_hook_in_double_grad(self):
def double_print_hook(grad):
grad = grad * 2
print(grad)
return grad
x = paddle.ones(shape=[1], dtype='float32')
x.stop_gradient = False
# the hook only takes effect in backward();
# the gradient of x returned by paddle.grad
# is not processed by the hook
x.register_hook(double_print_hook)
y = x * x
# Since y = x * x, dx = 2 * x
dx = paddle.grad(
outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0]
z = y + dx
self.assertTrue(x.grad is None)
# If create_graph = True, the gradient of dx
# would be backpropagated. Therefore,
# z = x * x + dx = x * x + 2 * x, and
# x.gradient() = 2 * x + 2 = 4.0
# after changed by hook: 8.0
z.backward()
self.assertTrue(np.array_equal(x.grad, np.array([8.])))
def test_remove_one_hook_multiple_times(self):
for device in self.devices:
paddle.set_device(device)
x = paddle.to_tensor([1., 2., 3., 4.])
x.stop_gradient = False
h = x.register_hook(lambda grad: grad * 2)
self.assertTrue(h.remove())
self.assertFalse(h.remove())
def test_register_hook_for_stop_gradient_var(self):
for device in self.devices:
paddle.set_device(device)
x = paddle.to_tensor([1., 2., 3., 4.])
with self.assertRaises(RuntimeError):
x.register_hook(lambda grad: grad * 2)
if __name__ == '__main__':
unittest.main()