Unverified · Commit 7ba85aca authored by Chen Weihang, committed by GitHub

Add inner register backward hook method for Tensor (#32171)

* add register backward hook method

* add leaf grad accumulated test
Parent f3e49c40
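For orientation before the diff: a minimal Python sketch of the user-facing behavior this change enables, based on the bindings and unit tests further down. It is only a sketch and assumes a Paddle build that already contains this commit; `register_hook` is the existing gradient hook, `_register_backward_hook` is the new inner void hook.

```python
import paddle

# existing gradient hook: receives the gradient and may return a modified one
x = paddle.to_tensor([0., 1., 2., 3.], stop_gradient=False)
x.register_hook(lambda grad: grad * 2)

# new inner backward hook: a callable with no arguments and no return value,
# registered on a leaf Tensor; it fires once the Tensor's gradient has been
# fully accumulated
calls = {"n": 0}

def backward_hook():
    calls["n"] += 1

x._register_backward_hook(backward_hook)

y = (x * x).sum()
y.backward()

print(x.grad)      # doubled by the gradient hook
print(calls["n"])  # 1: the void hook ran after accumulation finished
```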
@@ -284,15 +284,15 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
   for (const auto& pair : bwd_ins) {
     for (size_t i = 0; i < pair.second.size(); ++i) {
       auto& var = pair.second[i];
-      if (var->HasHook()) {
+      if (var->HasVariableWrapperHook()) {
         if (tmp_ins_ptr == nullptr) {
           tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
         }
-        VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type
-                << "'s input `" << pair.first << "`'s var `" << var->Name()
-                << "`.";
+        VLOG(3) << "Call " << var->GetVariableWrapperHooks().size()
+                << " hooks of " << op_type << "'s input `" << pair.first
+                << "`'s var `" << var->Name() << "`.";
         auto tmp_var = var;
-        for (const auto& hook_pair : var->GetHooks()) {
+        for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
           tmp_var = (*hook_pair.second)(tmp_var);
         }
         (*tmp_ins_ptr)[pair.first][i] = tmp_var;
...
@@ -467,14 +467,14 @@ void GradientAccumulator::CallGradientHooks() {
                     platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                                          "is not initialized when "
                                                          "call gradient hook."));
-  if (var_->HasHook()) {
-    VLOG(3) << "Call " << var_->GetHooks().size()
+  if (var_->HasVariableWrapperHook()) {
+    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
             << " hooks of leaf gradient accumulator's inner var `"
             << var_->Name() << "`.";
     auto tmp_var = inner_var_;
     VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
-            << var_->GetHooks().size();
-    for (const auto& hook_pair : var_->GetHooks()) {
+            << var_->GetVariableWrapperHooks().size();
+    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
       tmp_var = (*hook_pair.second)(tmp_var);
     }
     inner_var_ = tmp_var;
@@ -495,10 +495,10 @@ void GradientAccumulator::CallReduceHooks() {
                     "Only can call reduce hooks after the "
                     "gradient accumulation is completed in "
                     "current batch or across batchs."));
-  if (var_->HasMutableHook()) {
-    for (const auto& hook : var_->GetMutableHooks()) {
+  if (var_->HasVoidHook()) {
+    for (const auto& hook : var_->GetVoidHooks()) {
       VLOG(3) << "call gradient accumulator backward hooks.";
-      (*hook)(var_);
+      (*hook)();
     }
   }
 }
...
@@ -23,32 +23,34 @@ namespace imperative {

 class VariableWrapper;

-/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ]
+/** [ VariableWrapper Hook ]
  *
- * @brief This hook functor is executed before the grad OpBase is executed,
- *        taking the input of the current grad OpBase as input, and
- *        executing python hooks (user-defined) or C++ hooks (developer-defined)
- *        to achieve the purpose of custom operations on the interior VarBase
- *        gradient.
+ * @brief This hook functor is executed before the grad OpBase is executed or
+ *        after gradient accumulation completed in current batch.
+ *        1. For interior var, VariableWrapper Hook take the input of the
+ *           current grad OpBase as input.
+ *        2. For leaf var, VariableWrapper Hook take the inner_var_ of
+ *           GradientAccumulator as input.
  *
- * @note This hook functor will not change the input gradient VarBase.
+ * @note This hook functor will not change the input gradient VariableWrapper,
+ *       but if you copy the input VariableWrapper and change the value of
+ *       Variable in VariableWrapper, the value of input will also be changed,
+ *       because they shared same PlaceHolder.
  *
- * @note [Why need to be OpBase `PreHook`, why not `PostHook`?]
+ * @note [ Why need to be OpBase `PreHook`, why not `PostHook`? ]
  *
- * 1. We expect If set OpBase post hook, when the op executed end, the
+ *   We expect If set OpBase post hook, when the op executed end, the
  *   op's output gradient may not be the final state, because it may need
  *   other op's gradient output to accumulated to it. But before op can
  *   be executed, the gradient output must have been accumulated to final
  *   value.
- * 2. We don’t want the hook to change its input Tensor value, so now
- *    we can't call all hooks in GradAccumulator.
  *
- * @note [Why only can be used for interior VarBase?]
+ * @note [ Why Leaf gradient is special? ]
  *
  * Because the leaf VarBase's GradVarBase has no GradOpNode, so leaf
  * GradVarBase has no next OpBase to executed, so if need to deal with
- * the leaf GradVarBase, cannot use this hook functor. For this case, we
- * deal with by other inplace hook method.
+ * the leaf GradVarBase, we should call hooks after gradient accumulation
+ * completed.
  */
 class VariableWrapperHook {
  public:
@@ -57,34 +59,22 @@ class VariableWrapperHook {
       const std::shared_ptr<VariableWrapper>& var) = 0;
 };

-/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ]
- *
- * @brief This hook functor is the Hook that operates on the current
- *        gradientafter the GradientAccumulator has accumulated the gradient.
- *        Leaf GradVarBase has no next OpBase, if we want to register hook
- *        for it, we also need to wait until the leaf GradVarBase accumulation
- *        is completed, so we can add post hook to GradientAccumulator.
- *
- * @note This hook functor will change the grad VarBase value.
- *
- * @note Only allow leaf VarBase hold call this hook functor.
- */
-class InplaceVariableWrapperHook {
- public:
-  virtual ~InplaceVariableWrapperHook() = default;
-  virtual void operator()(VariableWrapper* var) = 0;
-};
-
-class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook {
+class CppVariableWrapperHook : public VariableWrapperHook {
  public:
-  explicit LambdaInplaceVariableWrapperHook(
-      std::function<void(VariableWrapper*)>&& fn)
+  explicit CppVariableWrapperHook(
+      std::function<std::shared_ptr<VariableWrapper>(
+          const std::shared_ptr<VariableWrapper>&)>&& fn)
       : fn_(std::move(fn)) {}

-  void operator()(VariableWrapper* var) override { fn_(var); }
+  std::shared_ptr<VariableWrapper> operator()(
+      const std::shared_ptr<VariableWrapper>& var) override {
+    return fn_(var);
+  }

  private:
-  std::function<void(VariableWrapper*)> fn_;
+  std::function<std::shared_ptr<VariableWrapper>(
+      const std::shared_ptr<VariableWrapper>&)>
+      fn_;
 };

 }  // namespace imperative
...
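The rewritten header comment above separates hooks on interior gradients (fed the input of the grad OpBase) from hooks on leaf gradients (fed the GradientAccumulator's inner_var_ after accumulation). A rough Python-level illustration of the leaf case follows; it is a sketch modeled on the `test_hook_for_accumulated_grad_leaf_var` case added later in this diff and assumes a build with this change.

```python
import numpy as np
import paddle

x = paddle.to_tensor([0., 1., 2., 4.], stop_gradient=False)
y = paddle.to_tensor([4., 5., 6., 7.], stop_gradient=False)
z = paddle.to_tensor([1., 2., 3., 4.], stop_gradient=False)

# hook on a leaf var: it is fed the already-accumulated gradient
# (the GradientAccumulator's inner_var_), i.e. the sum over both branches
x.register_hook(lambda grad: grad * 2)

o1 = x + y
o2 = x + z
o = o1.matmul(o2)   # dot product of two length-4 vectors
o.backward()

# accumulated grad of x is o1 + o2 = [5., 9., 13., 19.]; the hook doubles it
assert np.array_equal(x.grad, np.array([5., 9., 13., 19.]) * 2)
```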
@@ -226,23 +226,25 @@ class VarBase {
   void BumpInplaceVersion();

   /* Hook related method: now only used for GradVarBase */
-  bool HasHook() const { return var_->HasHook(); }
+  bool HasVariableWrapperHook() const { return var_->HasVariableWrapperHook(); }

-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    return var_->AddHook(
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    return var_->AddVariableWrapperHook(
         std::forward<std::shared_ptr<VariableWrapperHook>>(hook));
   }

-  bool RemoveHook(const int64_t& hook_id) { return var_->RemoveHook(hook_id); }
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    return var_->RemoveVariableWrapperHook(hook_id);
+  }

-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return var_->GetHooks();
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_->GetVariableWrapperHooks();
   }

-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    var_->AddMutableHook(
-        std::forward<std::shared_ptr<InplaceVariableWrapperHook>>(hook));
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    var_->AddVoidHook(
+        std::forward<std::shared_ptr<std::function<void()>>>(hook));
   }

  private:
...
@@ -310,9 +310,8 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   for (size_t global_var_index = 0; global_var_index < vars_.size();
        ++global_var_index) {
     auto var = vars_[global_var_index];
-    var->GradVarBase()->AddMutableHook(
-        std::make_shared<LambdaInplaceVariableWrapperHook>([=](
-            VariableWrapper *grad) { this->AddDistHook(global_var_index); }));
+    var->GradVarBase()->AddVoidHook(std::make_shared<std::function<void()>>(
+        [=]() { this->AddDistHook(global_var_index); }));
     var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
   }
...
@@ -37,6 +37,30 @@ namespace imperative {
 using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;

 using var_pair = std::pair<std::string, vb_vector>;

+std::shared_ptr<imperative::VariableWrapper> DoubleHook(
+    const std::shared_ptr<imperative::VariableWrapper>& var) {
+  // 1. create out var
+  auto out_var = std::make_shared<imperative::VariableWrapper>(var->Name());
+  out_var->SetType(var->Type());
+  out_var->SetDataType(var->DataType());
+  out_var->SetForwardDataType(var->ForwardDataType());
+  out_var->InnerSetOverridedStopGradient(var->InnerOverridedStopGradient());
+
+  // 2. get input and output var's tensor
+  auto* out_tensor = out_var->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto& tensor = var->Var().Get<framework::LoDTensor>();
+  out_tensor->Resize(tensor.dims());
+
+  // 3. double calc
+  auto* data = tensor.data<float>();
+  auto* out_data = out_tensor->mutable_data<float>(platform::CPUPlace());
+  for (int64_t i = 0; i < out_tensor->numel(); ++i) {
+    out_data[i] = data[i] * 2.0;
+  }
+
+  return out_var;
+}
+
 TEST(TestHooks, TestGradVarLeafBackwardHook) {
   // 1. prepare
   Tracer tracer;
@@ -73,16 +97,14 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   framework::AttributeMap mul_attr_map;
   mul_attr_map["use_mkldnn"] = false;

-  // add GradAccumulatorPostHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 10; }));

   // 2. forward
   tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
@@ -98,12 +120,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   engine.Init(tensors, grad_tensors);
   engine.Execute();

+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 8.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 10);

   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
@@ -152,16 +177,14 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   memory::Copy(place, mutable_z, place, src_data.data(),
                sizeof(float) * src_data.size());

-  // add ReduceBackwardHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 100; }));

   // 2. forward
   var_pair x_pair = var_pair("X", vb_vector(1, x));
@@ -199,12 +222,15 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   engine.Init(tensors, grad_tensors);
   engine.Execute();

+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 16.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 100);

   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
...
@@ -220,35 +220,35 @@ class VariableWrapper {
   }

   /* Hook related methods */
-  bool HasHook() const { return !hooks_.empty(); }
-
-  bool HasMutableHook() const { return !mutable_hooks_.empty(); }
+  bool HasVariableWrapperHook() const { return !var_hooks_.empty(); }

-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    hooks_.emplace(next_hook_id_, std::move(hook));
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    var_hooks_.emplace(next_hook_id_, std::move(hook));
     return next_hook_id_++;
   }

-  bool RemoveHook(const int64_t& hook_id) {
-    auto remove_cnt = hooks_.erase(hook_id);
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    auto remove_cnt = var_hooks_.erase(hook_id);
     if (remove_cnt == 0) {
       return false;
     }
     return true;
   }

-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return hooks_;
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_hooks_;
   }

-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    mutable_hooks_.emplace_back(std::move(hook));
+  bool HasVoidHook() const { return !void_hooks_.empty(); }
+
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    void_hooks_.emplace_back(std::move(hook));
   }

-  const std::vector<std::shared_ptr<InplaceVariableWrapperHook>>&
-  GetMutableHooks() const {
-    return mutable_hooks_;
+  const std::vector<std::shared_ptr<std::function<void()>>>& GetVoidHooks()
+      const {
+    return void_hooks_;
   }

  private:
@@ -319,14 +319,19 @@ class VariableWrapper {
   // isn't need
   bool is_empty_{false};

-  // NOTE(chenweihang): only grad var can hold hooks now
+  // NOTE(chenweihang): only grad var will hold hooks now
   int64_t next_hook_id_{0};
-  // Hooks used to register hook for grad var, support adding and removing,
+  // [ Hooks with VariableWrapper as input and output ]
+  // NOTE: Now registered for grad var, support adding and removing,
   // key is the accumulated int64_t value
-  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> hooks_;
-  // Hooks executed after the execution of the entire backward process is over,
-  // currently only supported for reducing in distributed training
-  std::vector<std::shared_ptr<InplaceVariableWrapperHook>> mutable_hooks_;
+  // NOTE: Var hook need to support removing, so need hook id
+  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> var_hooks_;
+  // [ Hooks without input and output ]
+  // NOTE: Now registered after the execution of the entire backward
+  // process is over, currently only used for reducing in distributed
+  // training
+  // NOTE: Now no need to support remove void hook
+  std::vector<std::shared_ptr<std::function<void()>>> void_hooks_;
 };

 }  // namespace imperative
...
@@ -1069,20 +1069,58 @@ void BindImperative(py::module *m_ptr) {
       .def("_register_grad_hook",
           [](imperative::VarBase &self, const py::handle &hook) {
             PADDLE_ENFORCE_EQ(
-                self.HasGradVar(), true,
+                !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
-                    "Cannot register hook on a tensor without gradient."));
-            return self.GradVarBase()->AddHook(
+                    "Cannot register gradient hook on a Tensor that stop "
+                    "gradient or without gradient."));
+            return self.GradVarBase()->AddVariableWrapperHook(
                 std::make_shared<PyVariableWrapperHook>(hook.ptr()));
           })
       .def("_remove_grad_hook",
           [](imperative::VarBase &self, int64_t hook_id) {
             PADDLE_ENFORCE_EQ(
-                self.HasGradVar(), true,
+                !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
-                    "Cannot remove hook on a tensor without gradient."));
-            return self.GradVarBase()->RemoveHook(hook_id);
+                    "Cannot remove gradient hook on a Tensor that stop "
+                    "gradient or without gradient."));
+            return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
           })
+      .def("_register_backward_hook",
+          [](imperative::VarBase &self, const py::handle &hook) {
+            PADDLE_ENFORCE_EQ(
+                self.IsLeaf(), true,
+                platform::errors::InvalidArgument(
+                    "Only can register backward hook for leaf Tensor."));
+            PADDLE_ENFORCE_EQ(
+                !self.OverridedStopGradient() && self.HasGradVar(), true,
+                platform::errors::InvalidArgument(
+                    "Cannot register backward hook on a Tensor that stop "
+                    "gradient or without gradient."));
+            auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
+            self.GradVarBase()->AddVoidHook(
+                std::make_shared<std::function<void()>>(py_func));
+          },
+          R"DOC(
+            Registers a backward hook for current Tensor.
+
+            This hook will be called every time the gradient of current Tensor has been fully calculated.
+
+            There are two differences with `_register_grad_hook`:
+            1. This backward hook will be executed after the gradient accumulation completed across batchs,
+               but the hook registered by `_register_grad_hook` will be executed the gradient accumulation
+               completed in current batch.
+            2. This backward hook function should have the following signature:
+
+                 hook() -> None
+
+               It requires no input and no return value.
+
+            Args:
+                hook(function): A backward hook to be registered for Tensor.gradient
+
+            Returns:
+                None
+          )DOC")
       .def("cpu",
           [](const std::shared_ptr<imperative::VarBase> &self) {
             if (platform::is_cpu_place(self->Place())) {
@@ -1301,21 +1339,15 @@ void BindImperative(py::module *m_ptr) {
                    &imperative::VarBase::SetOverridedStopGradient)
       .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
-      .def_property_readonly("shape",
+      .def_property_readonly(
+          "shape",
          [](imperative::VarBase &self) {
            if (self.Var().IsType<framework::LoDTensor>()) {
              return framework::vectorize<int>(
-                  self.Var()
-                      .Get<framework::LoDTensor>()
-                      .dims());
-            } else if (self.Var()
-                           .IsType<
-                               framework::SelectedRows>()) {
+                  self.Var().Get<framework::LoDTensor>().dims());
+            } else if (self.Var().IsType<framework::SelectedRows>()) {
              return framework::vectorize<int>(
-                  self.Var()
-                      .Get<framework::SelectedRows>()
-                      .value()
-                      .dims());
+                  self.Var().Get<framework::SelectedRows>().value().dims());
            } else {
              VLOG(2) << "It is meaningless to get shape of "
                         "variable type "
...
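The checks in the binding above surface in Python as ValueError, which is what the unit tests at the end of this diff assert. Below is a small sketch of the allowed and rejected cases, assuming a build that contains this change.

```python
import paddle

def on_grad_ready():
    print("leaf gradient fully accumulated")

# allowed: leaf Tensor that does not stop gradient
x = paddle.to_tensor(5., stop_gradient=False)
x._register_backward_hook(on_grad_ready)

# rejected: interior Tensor (produced by an op, not a leaf)
y = paddle.pow(x, 4.0)
try:
    y._register_backward_hook(on_grad_ready)
except ValueError as e:
    print("interior Tensor rejected:", e)

# rejected: Tensor with stop_gradient=True (it has no grad var)
w = paddle.to_tensor(5.)
try:
    w._register_backward_hook(on_grad_ready)
except ValueError as e:
    print("stop-gradient Tensor rejected:", e)
```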
@@ -178,8 +178,9 @@ class TestTensorRegisterHook(unittest.TestCase):
         # register hook and removed
         run_double_hook_for_leaf_var(lambda grad: grad * 2, removed=True)

-    def test_hook_for_accumulated_grad(self):
-        def run_double_hook_for_accumulated_grad(double_hook, removed=False):
+    def test_hook_for_accumulated_grad_interior_var(self):
+        def run_double_hook_for_accumulated_grad_interior_var(double_hook,
+                                                              removed=False):
             for device in self.devices:
                 paddle.set_device(device)
@@ -227,9 +228,50 @@ class TestTensorRegisterHook(unittest.TestCase):
                                    if not removed else base_grad))

         # register hook
-        run_double_hook_for_accumulated_grad(lambda grad: grad * 2)
+        run_double_hook_for_accumulated_grad_interior_var(lambda grad: grad * 2)
         # register hook and removed
-        run_double_hook_for_accumulated_grad(
+        run_double_hook_for_accumulated_grad_interior_var(
+            lambda grad: grad * 2, removed=True)
+
+    def test_hook_for_accumulated_grad_leaf_var(self):
+        def run_double_hook_for_accumulated_grad_leaf_var(double_hook,
+                                                          removed=False):
+            for device in self.devices:
+                paddle.set_device(device)
+
+                x = paddle.to_tensor([0., 1., 2., 4.])
+                x.stop_gradient = False
+                helper = x.register_hook(double_hook)
+
+                y = paddle.to_tensor([4., 5., 6., 7.])
+                z = paddle.to_tensor([1., 2., 3., 4.])
+                y.stop_gradient = False
+                z.stop_gradient = False
+
+                o1 = x + y
+                o2 = x + z
+                o1.stop_gradient = False
+                o2.stop_gradient = False
+
+                o = o1.matmul(o2)
+
+                # remove hook before backward
+                if removed:
+                    helper.remove()
+
+                o.backward()
+
+                base_grad = np.array([5., 9., 13., 19.])
+                # x.grad is changed by x.hook
+                self.assertTrue(
+                    np.array_equal(x.grad, base_grad * 2
+                                   if not removed else base_grad))
+
+        # register hook
+        run_double_hook_for_accumulated_grad_leaf_var(lambda grad: grad * 2)
+        # register hook and removed
+        run_double_hook_for_accumulated_grad_leaf_var(
             lambda grad: grad * 2, removed=True)

     def test_hook_in_model(self):
@@ -409,5 +451,54 @@ class TestTensorRegisterHook(unittest.TestCase):
                 x.register_hook(lambda grad: grad * 2)

+
+HOOK_INIT_VALUE = 10
+HOOK_IS_CALLED = False
+
+
+def global_void_hook():
+    global HOOK_INIT_VALUE
+    global HOOK_IS_CALLED
+    HOOK_INIT_VALUE *= 2
+    HOOK_IS_CALLED = True
+
+
+class TestTensorRegisterBackwardHook(unittest.TestCase):
+    def setUp(self):
+        self.devices = ["cpu"]
+        if paddle.is_compiled_with_cuda():
+            self.devices.append("gpu")
+
+    def test_register_backward_hook(self):
+        global HOOK_INIT_VALUE
+        global HOOK_IS_CALLED
+        for device in self.devices:
+            x = paddle.to_tensor(5., stop_gradient=False)
+            x._register_backward_hook(global_void_hook)
+            for i in range(5):
+                y = paddle.pow(x, 4.0)
+                y.backward()
+
+            self.assertEqual(HOOK_INIT_VALUE, 320)
+            self.assertTrue(HOOK_IS_CALLED)
+
+            # reset initial value
+            HOOK_INIT_VALUE = 10
+            HOOK_IS_CALLED = False
+
+    def test_register_backward_hook_for_interior_var(self):
+        x = paddle.to_tensor(5., stop_gradient=False)
+        y = paddle.pow(x, 4.0)
+
+        with self.assertRaises(ValueError):
+            y._register_backward_hook(global_void_hook)
+
+    def test_register_backward_hook_for_var_without_gradient(self):
+        x = paddle.to_tensor(5.)
+        y = paddle.pow(x, 4.0)
+
+        with self.assertRaises(ValueError):
+            x._register_backward_hook(global_void_hook)
+
+
 if __name__ == '__main__':
     unittest.main()