From 7eeb99fe025c0946014956300930461bf3ad8fe9 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Wed, 18 Nov 2020 13:09:21 +0800
Subject: [PATCH] Add basic hook classes for dygraph & implement reduce hook (#28584)

* add base hook classes and reduce hook impl

* fix constructor typo

* polish comment format

* refactor basic hook class design

* polish design details
---
 paddle/fluid/imperative/basic_engine.cc       |  21 +-
 paddle/fluid/imperative/basic_engine.h        |   3 +
 .../fluid/imperative/gradient_accumulator.cc  |   9 +-
 .../fluid/imperative/gradient_accumulator.h   |  49 ++++
 paddle/fluid/imperative/hooks.h               | 233 +++++++++++++++++
 paddle/fluid/imperative/op_base.h             |   2 +-
 paddle/fluid/imperative/tests/CMakeLists.txt  |   1 +
 paddle/fluid/imperative/tests/test_hooks.cc   | 240 ++++++++++++++++++
 paddle/fluid/imperative/variable_wrapper.h    |  82 ++++++
 9 files changed, 637 insertions(+), 3 deletions(-)
 create mode 100644 paddle/fluid/imperative/hooks.h
 create mode 100644 paddle/fluid/imperative/tests/test_hooks.cc

diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 9ad30506b2..e9214a8fea 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -114,6 +114,16 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
 
       accumulator->IncreaseRefCnt();
 
+      if (var->HasLeafHooks()) {
+        VLOG(3) << "Grad variable wrapper (" << var->Name()
+                << ") has leaf grad hooks.";
+        PADDLE_ENFORCE_NE(var->HasGradNode(), true,
+                          platform::errors::PermissionDenied(
+                              "Only leaf Tensor's gradient can append hook to "
+                              "GradientAccumulator."));
+        accumulator->SetPostHooks(var->GetLeafHooks());
+      }
+
       VLOG(3) << "Prepare to accumulate variable grad " << var->Name() << "("
               << var.get() << ") with reference count "
               << accumulator->RefCnt();
@@ -204,6 +214,7 @@ void BasicEngine::Execute() {
                 var->Name()));
 
         if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) {
+          no_need_run_accumulators_.emplace_back(iter->second.get());
           continue;
         }
 
@@ -220,12 +231,19 @@ void BasicEngine::Execute() {
                                     cur_op.place());
       }
 
-      // Step 2: Sum Gradient
+      // Step 2: Sum Gradient & Call Accumulator Hooks
+      for (auto* accumulator : no_need_run_accumulators_) {
+        if (accumulator->HasPostHooks()) {
+          accumulator->CallBackwardPostHooks();
+        }
+      }
+
       for (auto& pair : need_accu_var_list_) {
         pair.first->Add(std::move(pair.second), cur_op.id());
       }
 
       need_accu_var_list_.clear();
+      no_need_run_accumulators_.clear();
 
       VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
       if (!retain_graph_) {
@@ -258,6 +276,7 @@ void BasicEngine::Clear() {
   node_deps_.clear();
   accumulators_.clear();
   need_accu_var_list_.clear();
+  no_need_run_accumulators_.clear();
 }
 
 }  // namespace imperative
diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h
index 0906dd4f92..92e7fe7eb8 100644
--- a/paddle/fluid/imperative/basic_engine.h
+++ b/paddle/fluid/imperative/basic_engine.h
@@ -49,6 +49,9 @@ class BasicEngine : public Engine {
       accumulators_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
+  // Accumulators that do not need to perform accumulation operations
+  // (their ref_cnt_ == 1), corresponding to need_accu_var_list_
+  std::vector<GradientAccumulator*> no_need_run_accumulators_;
 
   bool retain_graph_;
 };
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 07f1868b7f..00fd18e5e2 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -401,13 +401,15 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
       }
     }
   }
-  ++cur_cnt_;
 
   if (var_->Var().IsType<framework::LoDTensor>()) {
     var_->SetType(framework::proto::VarType::LOD_TENSOR);
   } else if (var_->Var().IsType<framework::SelectedRows>()) {
     var_->SetType(framework::proto::VarType::SELECTED_ROWS);
   }
+
+  // Increase count & call post hooks
+  IncreaseCurCnt();
 }
 
 void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
@@ -520,6 +522,11 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
   } else if (var_->Var().IsType<framework::SelectedRows>()) {
     var_->SetType(framework::proto::VarType::SELECTED_ROWS);
   }
+
+  // call post hooks
+  if (HasPostHooks()) {
+    CallBackwardPostHooks();
+  }
 }
 
 }  // namespace imperative
diff --git a/paddle/fluid/imperative/gradient_accumulator.h b/paddle/fluid/imperative/gradient_accumulator.h
index a8ccb2a38d..2d0cc6e892 100644
--- a/paddle/fluid/imperative/gradient_accumulator.h
+++ b/paddle/fluid/imperative/gradient_accumulator.h
@@ -17,6 +17,8 @@
 #include <memory>
 #include <utility>
 #include <vector>
+
+#include "paddle/fluid/imperative/hooks.h"
 #include "paddle/fluid/imperative/layer.h"
 
 namespace paddle {
@@ -35,9 +37,43 @@ class GradientAccumulator {
 
   inline size_t RefCnt() const { return ref_cnt_; }
 
+  /* Hook related methods */
+  inline bool HasPostHooks() const { return !post_hooks_.expired(); }
+
+  void SetPostHooks(const std::shared_ptr<LeafVarHookPipeline>& hooks) {
+    PADDLE_ENFORCE_NOT_NULL(
+        hooks, platform::errors::InvalidArgument(
+                   "The hook set to GradientAccumulator is nullptr."));
+
+    auto shared_hooks = post_hooks_.lock();
+    if (shared_hooks != hooks) {
+      PADDLE_ENFORCE_EQ(
+          shared_hooks, nullptr,
+          platform::errors::PermissionDenied(
+              "Cannot set post hooks twice to GradientAccumulator."));
+      post_hooks_ = hooks;
+    }
+  }
+
+  // call backward post hooks, such as the reduce hook
+  void CallBackwardPostHooks() {
+    PADDLE_ENFORCE_NE(
+        post_hooks_.expired(), true,
+        platform::errors::NotFound(
+            "The post hooks of GradientAccumulator for Tensor `%s` expired.",
+            var_->Name()));
+    auto shared_hooks = post_hooks_.lock();
+    for (const auto& hook : shared_hooks->backward_hooks()) {
+      VLOG(3) << "call gradient accumulator backward hooks.";
+      (*hook)(var_);
+    }
+  }
+
  protected:
   VariableWrapper* var_;
   size_t ref_cnt_{0};
+
+  std::weak_ptr<LeafVarHookPipeline> post_hooks_;
 };
 
 class EagerGradientAccumulator : public GradientAccumulator {
@@ -47,6 +83,19 @@ class EagerGradientAccumulator : public GradientAccumulator {
   void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id,
            bool unchange_input) override;
 
+ private:
+  inline bool AccumulateCompleted() const { return cur_cnt_ == ref_cnt_; }
+
+  void IncreaseCurCnt() {
+    ++cur_cnt_;
+    VLOG(3) << "IncreaseCurCnt: cur_cnt " << cur_cnt_ << ", ref_cnt "
+            << ref_cnt_;
+    // After all tmp gradients are accumulated into the grad var, run hooks
+    if (AccumulateCompleted() && HasPostHooks()) {
+      CallBackwardPostHooks();
+    }
+  }
+
  private:
   size_t cur_cnt_{0};
 };
diff --git a/paddle/fluid/imperative/hooks.h b/paddle/fluid/imperative/hooks.h
new file mode 100644
index 0000000000..1211ec6ae6
--- /dev/null
+++ b/paddle/fluid/imperative/hooks.h
@@ -0,0 +1,233 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace imperative {
+
+class VariableWrapper;
+
+/** [ Basic hook classes ]
+ *
+ * @brief OpBasePreHook is executed before the grad OpBase is executed. It
+ *        takes the inputs of the current grad OpBase as its input and runs
+ *        Python hooks (user-defined) or C++ hooks (developer-defined) to
+ *        perform custom operations on the gradients of interior VarBases.
+ *
+ * @note  OpBasePreHook will not change the input gradient VarBase.
+ *
+ * @note  [Why an OpBase `PreHook`, and not a `PostHook`?]
+ *
+ *        With an OpBase post hook, the op's output gradient may not yet be
+ *        in its final state when the op finishes executing, because
+ *        gradients produced by other ops may still need to be accumulated
+ *        into it. In contrast, before an op can be executed, its input
+ *        gradients must already have been accumulated to their final values.
+ *
+ * @note  [Why can it only be used for interior VarBases?]
+ *
+ *        A leaf VarBase's GradVarBase has no GradOpNode, so there is no next
+ *        OpBase to execute for it, and OpBasePreHook cannot handle it. Leaf
+ *        GradVarBases are handled by GradAccumulatorPostHook instead.
+ */
+class OpBasePreHook {
+ public:
+  virtual ~OpBasePreHook() = default;
+  virtual VariableWrapperList operator()(
+      const VariableWrapperList& grad_inputs) = 0;
+};
+
+/**
+ * @brief GradAccumulatorPostHook is the hook that operates on the current
+ *        gradient after the GradientAccumulator has accumulated it. A leaf
+ *        GradVarBase has no next OpBase, so to register a hook for it we
+ *        have to wait until the leaf GradVarBase's accumulation completes;
+ *        therefore the post hook is added to the GradientAccumulator.
+ *
+ * @note  GradAccumulatorPostHook will change the grad VarBase value.
+ *
+ * @note  Only leaf VarBases are allowed to hold GradAccumulatorPostHooks.
+ */
+class GradAccumulatorPostHook {
+ public:
+  virtual ~GradAccumulatorPostHook() = default;
+  virtual void operator()(VariableWrapper* var) = 0;
+};
+
+/** [ Hooks for C++ functions ]
+ *
+ * Here we design three C++ hooks:
+ * 1. CppOpBasePreHook (to be implemented later):
+ *    - used for developer-defined C++ interior VarBase hooks
+ * 2. CppGradAccumulatorPostHook (to be implemented later):
+ *    - used for developer-defined C++ leaf VarBase hooks
+ * 3. LambdaGradAccumulatorPostHook:
+ *    - used for VarBase reduce in parallel training
+ *
+ * @note [Why are two types of GradAccumulatorPostHook needed?]
+ *
+ *       There are two types of gradient accumulation:
+ *       1. Gradient accumulation within the same batch
+ *       2. Gradient accumulation across batches
+ *       The order of execution between hooks and gradient accumulation is:
+ *
+ *         [ Gradient accumulation within the same batch ]
+ *                              |
+ *                  [ leaf GradVarBase hooks ]
+ *                              |
+ *         [ Gradient accumulation across batches ]
+ *                              |
+ *              [ Gradient reduce / allreduce ]
+ *
+ *       Because we currently perform both kinds of accumulation in one
+ *       GradientAccumulator, we must distinguish between the two types of
+ *       hooks.
+ *
+ *       LambdaGradAccumulatorPostHook cannot be registered by users
+ *       directly; it is currently only used to support the reduce strategy
+ *       of parallel multi-card training.
+ */
+class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook {
+ public:
+  explicit LambdaGradAccumulatorPostHook(
+      std::function<void(VariableWrapper*)> fn)
+      : fn_(std::move(fn)) {}
+
+  void operator()(VariableWrapper* var) override { fn_(var); }
+
+ private:
+  std::function<void(VariableWrapper*)> fn_;
+};
+
+/* Hooks for Python functions: see pybind/imperative.cc */
+
+/** Python hooks to be added later:
+ *  - PyOpBasePreHook (to be implemented later): used for user-defined
+ *    interior Python VarBase hooks
+ *  - PyGradAccumulatorPostHook (to be implemented later): used for
+ *    user-defined leaf Python VarBase hooks
+ */
+
+/** [ Hook pipeline classes ]
+ *
+ * @note [Why are hook pipeline classes needed?]
+ *
+ *       There are two purposes for adding hook pipelines here:
+ *
+ *       1. Make the code implementation cleaner.
+ *
+ *          Without hook pipelines, we would need to add 3 hook vectors to
+ *          VariableWrapper, 1 hook vector to OpBase, and 2 hook vectors to
+ *          GradientAccumulator, e.g.:
+ *
+ *          - VariableWrapper:
+ *            std::vector<std::unique_ptr<OpBasePreHook>>
+ *              interior_var_hooks_;
+ *            std::vector<std::unique_ptr<GradAccumulatorPostHook>>
+ *              leaf_var_hooks_;
+ *            std::vector<std::unique_ptr<GradAccumulatorPostHook>>
+ *              backward_hooks_;
+ *
+ *          - OpBase:
+ *            std::vector<std::unique_ptr<OpBasePreHook>>
+ *              interior_var_hooks_;
+ *
+ *          - GradientAccumulator:
+ *            std::vector<std::unique_ptr<GradAccumulatorPostHook>>
+ *              leaf_var_hooks_;
+ *            std::vector<std::unique_ptr<GradAccumulatorPostHook>>
+ *              backward_hooks_;
+ *
+ *          This is more complicated, and such std::vector<std::unique_ptr<...>>
+ *          members are not easy to destruct correctly.
+ *
+ *       2. Make the code easier to understand.
+ *
+ *          From these two pipeline classes, we can clearly see that there
+ *          are two types of hooks: one for the interior gradient vars and
+ *          one for the leaf gradient vars inside the backward calculation
+ *          graph.
+ */
+
+class InteriorVarHookPipeline {
+ public:
+  InteriorVarHookPipeline() = default;
+
+  void add_hook(std::unique_ptr<OpBasePreHook>&& hook) {
+    hooks_.emplace_back(std::move(hook));
+  }
+
+  const std::vector<std::unique_ptr<OpBasePreHook>>& hooks() const {
+    return hooks_;
+  }
+
+  std::vector<std::unique_ptr<OpBasePreHook>>& hooks() { return hooks_; }
+
+ private:
+  std::vector<std::unique_ptr<OpBasePreHook>> hooks_;
+
+  DISABLE_COPY_AND_ASSIGN(InteriorVarHookPipeline);
+};
+
+class LeafVarHookPipeline {
+ public:
+  LeafVarHookPipeline() = default;
+
+  void add_hook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
+    hooks_.emplace_back(std::move(hook));
+  }
+
+  const std::vector<std::unique_ptr<GradAccumulatorPostHook>>& hooks() const {
+    return hooks_;
+  }
+
+  std::vector<std::unique_ptr<GradAccumulatorPostHook>>& hooks() {
+    return hooks_;
+  }
+
+  void add_backward_hook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
+    backward_hooks_.emplace_back(std::move(hook));
+  }
+
+  const std::vector<std::unique_ptr<GradAccumulatorPostHook>>& backward_hooks()
+      const {
+    return backward_hooks_;
+  }
+
+  std::vector<std::unique_ptr<GradAccumulatorPostHook>>& backward_hooks() {
+    return backward_hooks_;
+  }
+
+ private:
+  std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_;
+  // NOTE: `backward` here means the whole backward process; the
+  // `backward_hooks_` need to be executed after the whole backward
+  // process has finished.
+  std::vector<std::unique_ptr<GradAccumulatorPostHook>> backward_hooks_;
+
+  DISABLE_COPY_AND_ASSIGN(LeafVarHookPipeline);
+};
+
+}  // namespace imperative
+}  // namespace paddle
diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h
index a4b57c404c..36185af3a2 100644
--- a/paddle/fluid/imperative/op_base.h
+++ b/paddle/fluid/imperative/op_base.h
@@ -176,7 +176,7 @@ class OpBase {
   platform::Place place_;
   size_t id_{-1UL};
 
-  std::vector<std::function<void()>> backward_hooks_;
+  std::weak_ptr<InteriorVarHookPipeline> pre_hooks_;
 };
 
 class GradOpNode {
diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt
index e3c82474e0..a8de1e6b03 100644
--- a/paddle/fluid/imperative/tests/CMakeLists.txt
+++ b/paddle/fluid/imperative/tests/CMakeLists.txt
@@ -11,3 +11,4 @@ cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy se
 cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
 cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
 cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
+cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
diff --git a/paddle/fluid/imperative/tests/test_hooks.cc b/paddle/fluid/imperative/tests/test_hooks.cc
new file mode 100644
index 0000000000..7bf5f87668
--- /dev/null
+++ b/paddle/fluid/imperative/tests/test_hooks.cc
@@ -0,0 +1,240 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/imperative/basic_engine.h"
+#include "paddle/fluid/imperative/hooks.h"
+#include "paddle/fluid/imperative/tracer.h"
+#include "paddle/fluid/memory/memcpy.h"
+
+namespace platform = paddle::platform;
+namespace framework = paddle::framework;
+namespace memory = paddle::memory;
+
+DECLARE_bool(sort_sum_gradient);
+
+namespace paddle {
+namespace imperative {
+
+using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
+using var_pair = std::pair<std::string, vb_vector>;
+
+TEST(TestHooks, TestGradVarLeafBackwardHook) {
+  // 1. prepare
+  Tracer tracer;
+  std::shared_ptr<VarBase> x(new VarBase(true, "x"));
+  std::shared_ptr<VarBase> y(new VarBase(true, "y"));
+  std::shared_ptr<VarBase> out(new VarBase(true, "out"));
+  x->SetOverridedStopGradient(false);
+  y->SetOverridedStopGradient(false);
+
+  platform::CPUPlace place;
+  std::vector<float> src_data(10, 2.0);
+  std::vector<int64_t> x_dims = {2, 5};
+  std::vector<int64_t> y_dims = {5, 2};
+
+  auto* x_tensor = x->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto* y_tensor = y->MutableVar()->GetMutable<framework::LoDTensor>();
+
+  x_tensor->Resize(framework::make_ddim(x_dims));
+  auto* mutable_x = x_tensor->mutable_data<float>(place);
+  memory::Copy(place, mutable_x, place, src_data.data(),
+               sizeof(float) * src_data.size());
+
+  y_tensor->Resize(framework::make_ddim(y_dims));
+  auto* mutable_y = y_tensor->mutable_data<float>(place);
+  memory::Copy(place, mutable_y, place, src_data.data(),
+               sizeof(float) * src_data.size());
+
+  var_pair x_pair = var_pair("X", vb_vector(1, x));
+  var_pair y_pair = var_pair("Y", vb_vector(1, y));
+  var_pair out_pair = var_pair("Out", vb_vector(1, out));
+
+  NameVarBaseMap ins = {x_pair, y_pair};
+  NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap mul_attr_map;
+  mul_attr_map["use_mkldnn"] = false;
+
+  // add GradAccumulatorPostHook
+  auto x_var_wrapper = x->SharedVar();
+  x_var_wrapper->AddGradVarLeafBackwardHook(
+      std::unique_ptr<LambdaGradAccumulatorPostHook>(
+          new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) {
+            auto* grad_tensor =
+                grad->MutableVar()->GetMutable<framework::LoDTensor>();
+            for (int i = 0; i < grad_tensor->numel(); ++i) {
+              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
+            }
+          })));
+
+  // 2. forward
+  tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
+
+  ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
+  ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
+  ASSERT_EQ(out->GradVarBase()->GradOpNum(), 1UL);
+
+  // 3. backward
+  BasicEngine engine;
+  engine.Init(out.get());
+  engine.Execute();
+
+  framework::LoDTensor x_grad;
+  framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
+                            &x_grad);
+  for (int i = 0; i < x_grad.numel(); ++i) {
+    ASSERT_EQ(x_grad.data<float>()[i], 8.0);
+  }
+
+  framework::LoDTensor y_grad;
+  framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
+                            &y_grad);
+
+  for (int i = 0; i < y_grad.numel(); ++i) {
+    ASSERT_EQ(y_grad.data<float>()[i], 4.0);
+  }
+}
+
+void GradVarLeafBackwardHookWithGradAccmulatedTest() {
+  // 1. prepare
+  Tracer tracer;
+  std::shared_ptr<VarBase> x(new VarBase(true, "x"));
+  std::shared_ptr<VarBase> y(new VarBase(true, "y"));
+  std::shared_ptr<VarBase> z(new VarBase(true, "z"));
+  std::shared_ptr<VarBase> out_xy(new VarBase(true, "out_xy"));
+  std::shared_ptr<VarBase> out_xz(new VarBase(true, "out_xz"));
+  std::shared_ptr<VarBase> out(new VarBase(true, "out"));
+  x->SetOverridedStopGradient(false);
+  y->SetOverridedStopGradient(false);
+  z->SetOverridedStopGradient(false);
+
+  platform::CPUPlace place;
+  std::vector<float> src_data(10, 2.0);
+  std::vector<int64_t> x_dims = {2, 5};
+  std::vector<int64_t> y_dims = {5, 2};
+  std::vector<int64_t> z_dims = {5, 2};
+
+  auto* x_tensor = x->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto* y_tensor = y->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto* z_tensor = z->MutableVar()->GetMutable<framework::LoDTensor>();
+
+  x_tensor->Resize(framework::make_ddim(x_dims));
+  auto* mutable_x = x_tensor->mutable_data<float>(place);
+  memory::Copy(place, mutable_x, place, src_data.data(),
+               sizeof(float) * src_data.size());
+
+  y_tensor->Resize(framework::make_ddim(y_dims));
+  auto* mutable_y = y_tensor->mutable_data<float>(place);
+  memory::Copy(place, mutable_y, place, src_data.data(),
+               sizeof(float) * src_data.size());
+
+  z_tensor->Resize(framework::make_ddim(z_dims));
+  auto* mutable_z = z_tensor->mutable_data<float>(place);
+  memory::Copy(place, mutable_z, place, src_data.data(),
+               sizeof(float) * src_data.size());
+
+  // add GradAccumulatorPostHook
+  auto x_var_wrapper = x->SharedVar();
+  x_var_wrapper->AddGradVarLeafBackwardHook(
+      std::unique_ptr<LambdaGradAccumulatorPostHook>(
+          new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) {
+            auto* grad_tensor =
+                grad->MutableVar()->GetMutable<framework::LoDTensor>();
+            for (int i = 0; i < grad_tensor->numel(); ++i) {
+              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
+            }
+          })));
+
+  // 2. forward
+  var_pair x_pair = var_pair("X", vb_vector(1, x));
+  var_pair y_pair = var_pair("Y", vb_vector(1, y));
+  var_pair out_xy_pair = var_pair("Out", vb_vector(1, out_xy));
+  NameVarBaseMap ins = {x_pair, y_pair};
+  NameVarBaseMap outs = {out_xy_pair};
+  framework::AttributeMap mul_attr_map;
+  mul_attr_map["use_mkldnn"] = false;
+  tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
+
+  var_pair z_pair = var_pair("Y", vb_vector(1, z));
+  var_pair out_xz_pair = var_pair("Out", vb_vector(1, out_xz));
+  ins = {x_pair, z_pair};
+  outs = {out_xz_pair};
+  tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
+
+  var_pair xy_pair = var_pair("X", vb_vector(1, out_xy));
+  var_pair xz_pair = var_pair("Y", vb_vector(1, out_xz));
+  var_pair out_pair = var_pair("Out", vb_vector(1, out));
+  ins = {xy_pair, xz_pair};
+  outs = {out_pair};
+  framework::AttributeMap add_attr_map;
+  tracer.TraceOp("elementwise_add", ins, outs, add_attr_map, place, true);
+
+  ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
+  ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
+  ASSERT_EQ(z->GradVarBase()->GradOpNum(), 0UL);
+  ASSERT_EQ(out->GradVarBase()->GradOpNum(), 1UL);
+
+  // 3. backward
+  BasicEngine engine;
+  engine.Init(out.get());
+  engine.Execute();
+
+  framework::LoDTensor x_grad;
+  framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
+                            &x_grad);
+  for (int i = 0; i < x_grad.numel(); ++i) {
+    ASSERT_EQ(x_grad.data<float>()[i], 16.0);
+  }
+
+  framework::LoDTensor y_grad;
+  framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
+                            &y_grad);
+
+  for (int i = 0; i < y_grad.numel(); ++i) {
+    ASSERT_EQ(y_grad.data<float>()[i], 4.0);
+  }
+
+  framework::LoDTensor z_grad;
+  framework::TensorCopySync(z->GradVar().Get<framework::LoDTensor>(), place,
+                            &z_grad);
+
+  for (int i = 0; i < z_grad.numel(); ++i) {
+    ASSERT_EQ(z_grad.data<float>()[i], 4.0);
+  }
+}
+
+TEST(TestHooks, TestGradVarLeafBackwardHookWithGradAccmulated) {
+  GradVarLeafBackwardHookWithGradAccmulatedTest();
+}
+
+TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) {
+  FLAGS_sort_sum_gradient = true;
+  GradVarLeafBackwardHookWithGradAccmulatedTest();
+  FLAGS_sort_sum_gradient = false;
+}
+
+}  // namespace imperative
+}  // namespace paddle
+
+USE_OP(mul);
+USE_OP(mul_grad);
+USE_OP(elementwise_add);
+USE_OP(elementwise_add_grad);
diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h
index d730ddc12d..e9b1ccc860 100644
--- a/paddle/fluid/imperative/variable_wrapper.h
+++ b/paddle/fluid/imperative/variable_wrapper.h
@@ -16,11 +16,16 @@
 
 #include <memory>
 #include <string>
+#include <utility>
+
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/imperative/hooks.h"
 
 namespace paddle {
 namespace imperative {
 
+class InteriorVarHookPipeline;
+class LeafVarHookPipeline;
 class VarBase;
 class GradOpNode;
@@ -133,6 +138,42 @@ class VariableWrapper {
     }
   }
 
+  /* Hook related methods: can only be called by GradVarBase */
+
+  bool HasInteriorHooks() const { return interior_hooks_ != nullptr; }
+
+  bool HasLeafHooks() const { return leaf_hooks_ != nullptr; }
+
+  void AddGradVarInteriorHook(std::unique_ptr<OpBasePreHook>&& hook) {
+    auto interior_hooks = GetGradVarInteriorHooksSafely();
+    interior_hooks->add_hook(std::move(hook));
+  }
+
+  void AddGradVarLeafHook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
+    auto leaf_hooks = GetGradVarLeafHooksSafely();
+    leaf_hooks->add_hook(std::move(hook));
+  }
+
+  void AddGradVarLeafBackwardHook(
+      std::unique_ptr<GradAccumulatorPostHook>&& hook) {
+    auto leaf_hooks = GetGradVarLeafHooksSafely();
+    leaf_hooks->add_backward_hook(std::move(hook));
+  }
+
+  const std::shared_ptr<InteriorVarHookPipeline>& GetInteriorHooks() const {
+    return interior_hooks_;
+  }
+
+  std::shared_ptr<InteriorVarHookPipeline>& GetInteriorHooks() {
+    return interior_hooks_;
+  }
+
+  const std::shared_ptr<LeafVarHookPipeline>& GetLeafHooks() const {
+    return leaf_hooks_;
+  }
+
+  std::shared_ptr<LeafVarHookPipeline>& GetLeafHooks() { return leaf_hooks_; }
+
  private:
   void SetGradVar(const std::shared_ptr<VarBase>& var) {
     auto shared_var = grad_var_.lock();
@@ -159,6 +200,41 @@ class VariableWrapper {
     }
   }
 
+  /* Hook related private methods */
+  std::shared_ptr<VariableWrapper> GetGradVarSafely() const {
+    auto shared_grad_var = grad_var_.lock();
+    PADDLE_ENFORCE_NOT_NULL(
+        shared_grad_var,
+        platform::errors::PermissionDenied(
+            "Cannot add gradient hook on Tensor without gradient."));
+    return shared_grad_var;
+  }
+
+  std::shared_ptr<InteriorVarHookPipeline>& GetGradVarInteriorHooksSafely() {
+    auto shared_grad_var = GetGradVarSafely();
+    PADDLE_ENFORCE_EQ(HasGradNode(), true,
+                      platform::errors::PermissionDenied(
+                          "Only interior Tensor in backward can register "
+                          "interior gradient hook."));
+    if (shared_grad_var->interior_hooks_ == nullptr) {
+      shared_grad_var->interior_hooks_ =
+          std::make_shared<InteriorVarHookPipeline>();
+    }
+    return shared_grad_var->interior_hooks_;
+  }
+
+  std::shared_ptr<LeafVarHookPipeline>& GetGradVarLeafHooksSafely() {
+    auto shared_grad_var = GetGradVarSafely();
+    PADDLE_ENFORCE_EQ(
+        HasGradNode(), false,
+        platform::errors::PermissionDenied(
+            "Only leaf Tensor in backward can register leaf gradient hook."));
+    if (shared_grad_var->leaf_hooks_ == nullptr) {
+      shared_grad_var->leaf_hooks_ = std::make_shared<LeafVarHookPipeline>();
+    }
+    return shared_grad_var->leaf_hooks_;
+  }
+
  private:
   framework::Variable var_;
   std::string name_;
@@ -173,6 +249,12 @@ class VariableWrapper {
 
   std::weak_ptr<VariableWrapper> grad_var_;
   std::weak_ptr<GradOpNode> grad_node_;
+
+  // NOTE: only the grad var can hold hooks now;
+  // only an interior var can hold interior hooks
+  std::shared_ptr<InteriorVarHookPipeline> interior_hooks_;
+  // only a leaf var can hold leaf hooks
+  std::shared_ptr<LeafVarHookPipeline> leaf_hooks_;
 };
 
 }  // namespace imperative
-- 
GitLab
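
How the pieces of this patch fit together: a GradAccumulatorPostHook registered through AddGradVarLeafBackwardHook is stored in the leaf grad VariableWrapper's LeafVarHookPipeline, handed to that variable's GradientAccumulator via SetPostHooks, and invoked only once the accumulator has received every expected gradient contribution (AccumulateCompleted()). The standalone sketch below illustrates that firing order with simplified stand-in types; it is not part of the patch and does not use Paddle's real classes, only names that loosely mirror them.

// hook_flow_demo.cc -- simplified stand-ins for illustration only.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace demo {

// Stand-in for imperative::VariableWrapper (holds just a scalar "gradient").
struct VariableWrapper {
  std::string name;
  float grad;
};

// Mirrors GradAccumulatorPostHook: operates on the accumulated gradient.
struct GradAccumulatorPostHook {
  virtual ~GradAccumulatorPostHook() = default;
  virtual void operator()(VariableWrapper* var) = 0;
};

// Mirrors LambdaGradAccumulatorPostHook: wraps a std::function.
class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook {
 public:
  explicit LambdaGradAccumulatorPostHook(std::function<void(VariableWrapper*)> fn)
      : fn_(std::move(fn)) {}
  void operator()(VariableWrapper* var) override { fn_(var); }

 private:
  std::function<void(VariableWrapper*)> fn_;
};

// Mirrors LeafVarHookPipeline::backward_hooks(): hooks that must run after
// the whole backward accumulation for this leaf variable has finished.
struct LeafVarHookPipeline {
  std::vector<std::unique_ptr<GradAccumulatorPostHook>> backward_hooks;
};

// Mirrors EagerGradientAccumulator: counts gradient contributions and fires
// the post hooks once the last expected one arrives (cur_cnt == ref_cnt).
class GradientAccumulator {
 public:
  GradientAccumulator(VariableWrapper* var, size_t ref_cnt,
                      std::shared_ptr<LeafVarHookPipeline> hooks)
      : var_(var), ref_cnt_(ref_cnt), post_hooks_(std::move(hooks)) {}

  void Add(float grad) {
    var_->grad += grad;
    if (++cur_cnt_ == ref_cnt_ && post_hooks_) {
      for (auto& hook : post_hooks_->backward_hooks) {
        (*hook)(var_);  // e.g. the reduce hook in parallel training
      }
    }
  }

 private:
  VariableWrapper* var_;
  size_t ref_cnt_;
  size_t cur_cnt_{0};
  std::shared_ptr<LeafVarHookPipeline> post_hooks_;
};

}  // namespace demo

int main() {
  demo::VariableWrapper x{"x", 0.f};

  // Register a leaf backward post hook that doubles the final gradient.
  auto hooks = std::make_shared<demo::LeafVarHookPipeline>();
  hooks->backward_hooks.emplace_back(new demo::LambdaGradAccumulatorPostHook(
      [](demo::VariableWrapper* v) { v->grad *= 2.f; }));

  // x receives two gradient contributions in this backward pass.
  demo::GradientAccumulator acc(&x, /*ref_cnt=*/2, hooks);
  acc.Add(3.f);  // hook does not fire: accumulation not complete yet
  acc.Add(1.f);  // last contribution -> hook fires, grad = (3 + 1) * 2
  std::cout << x.grad << std::endl;  // prints 8
  return 0;
}

This mirrors the flow exercised by TestGradVarLeafBackwardHook above, where the registered hook doubles x's gradient only after the accumulator has finished summing all contributions.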