Unverified · Commit 7eeb99fe · Authored by: Chen Weihang · Committed by: GitHub

Add basic hook classes for dygraph & implement reduce hook (#28584)

* add base hook classes and reduce hook impl

* fix constructor typo

* polish comment format

* refactor baisc hook class design

* polish design details
Parent 858ffa0c
@@ -114,6 +114,16 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
accumulator->IncreaseRefCnt();
if (var->HasLeafHooks()) {
VLOG(3) << "Grad variable wrapper (" << var->Name()
<< ") has leaf grad hooks.";
PADDLE_ENFORCE_NE(var->HasGradNode(), true,
platform::errors::PermissionDenied(
"Only leaf Tensor's gradient can append hook to "
"Gradientaccumulator."));
accumulator->SetPostHooks(var->GetLeafHooks());
}
VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
<< var.get() << ") with reference count "
<< accumulator->RefCnt();
@@ -204,6 +214,7 @@ void BasicEngine::Execute() {
var->Name()));
if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) {
no_need_run_accumulators_.emplace_back(iter->second.get());
continue;
}
@@ -220,12 +231,19 @@
cur_op.place());
}
// Step 2: Sum Gradient
// Step 2: Sum Gradient & Call Accumulator Hooks
for (auto* accumulator : no_need_run_accumulators_) {
if (accumulator->HasPostHooks()) {
accumulator->CallBackwardPostHooks();
}
}
for (auto& pair : need_accu_var_list_) {
pair.first->Add(std::move(pair.second), cur_op.id());
}
need_accu_var_list_.clear();
no_need_run_accumulators_.clear();
VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
if (!retain_graph_) {
@@ -258,6 +276,7 @@ void BasicEngine::Clear() {
node_deps_.clear();
accumulators_.clear();
need_accu_var_list_.clear();
no_need_run_accumulators_.clear();
}
} // namespace imperative
@@ -49,6 +49,9 @@ class BasicEngine : public Engine {
accumulators_;
std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_;
// Accumulators that do not need to perform accumulation operations
// (their ref_cnt_ == 1), corresponding to need_accu_var_list_
std::vector<GradientAccumulator*> no_need_run_accumulators_;
bool retain_graph_;
};
@@ -401,13 +401,15 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
}
}
}
++cur_cnt_;
if (var_->Var().IsType<framework::LoDTensor>()) {
var_->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (var_->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
}
// Increase count & call post hooks
IncreaseCurCnt();
}
void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
@@ -520,6 +522,11 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
} else if (var_->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
}
// call post hooks
if (HasPostHooks()) {
CallBackwardPostHooks();
}
}
} // namespace imperative
@@ -17,6 +17,8 @@
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/layer.h"
namespace paddle {
@@ -35,9 +37,43 @@ class GradientAccumulator {
inline size_t RefCnt() const { return ref_cnt_; }
/* Hook related methods */
inline bool HasPostHooks() const { return !post_hooks_.expired(); }
void SetPostHooks(const std::shared_ptr<LeafVarHookPipeline>& hooks) {
PADDLE_ENFORCE_NOT_NULL(
hooks, platform::errors::InvalidArgument(
"The hook set to GradientAccumulator is nullptr."));
auto shared_hooks = post_hooks_.lock();
if (shared_hooks != hooks) {
PADDLE_ENFORCE_EQ(
shared_hooks, nullptr,
platform::errors::PermissionDenied(
"Cannot set post hooks twice to GradientAccumulator."));
post_hooks_ = hooks;
}
}
// call backward post hooks, such as reduce hook
void CallBackwardPostHooks() {
PADDLE_ENFORCE_NE(
post_hooks_.expired(), true,
platform::errors::NotFound(
"The post hooks of GradientAccumulator for Tensor `%s` expired.",
var_->Name()));
auto shared_hooks = post_hooks_.lock();
for (const auto& hook : shared_hooks->backward_hooks()) {
VLOG(3) << "call gradient accumulator backward hooks.";
(*hook)(var_);
}
}
protected:
VariableWrapper* var_;
size_t ref_cnt_{0};
std::weak_ptr<LeafVarHookPipeline> post_hooks_;
};
class EagerGradientAccumulator : public GradientAccumulator {
@@ -47,6 +83,19 @@ class EagerGradientAccumulator : public GradientAccumulator {
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id,
bool unchange_input) override;
private:
inline bool AccumulateCompleted() const { return cur_cnt_ == ref_cnt_; }
void IncreaseCurCnt() {
++cur_cnt_;
VLOG(3) << "IncreaseCurCnt: cur_cnt " << cur_cnt_ << ", ref_cnt "
<< ref_cnt_;
// After all tmp gradients have been accumulated into the grad var, run hooks
if (AccumulateCompleted() && HasPostHooks()) {
CallBackwardPostHooks();
}
}
private:
size_t cur_cnt_{0};
};
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace imperative {
class VariableWrapper;
/** [ Basic hook classes ]
*
* @brief OpBasePreHook is executed before the grad OpBase is executed,
* taking the inputs of the current grad OpBase as its input and running
* Python hooks (user-defined) or C++ hooks (developer-defined) to perform
* custom operations on the gradient of an interior VarBase.
*
* @note OpBasePreHook will not change the input gradient VarBase.
*
* @note [ Why an OpBase `PreHook` rather than a `PostHook`? ]
*
* If we set a post hook on an OpBase, the op's output gradient may not yet
* be in its final state when the op finishes executing, because other ops'
* gradient outputs may still need to be accumulated into it. But before an
* op can be executed, its input gradients must already have been
* accumulated to their final values.
*
* @note [ Why can it only be used for an interior VarBase? ]
*
* A leaf VarBase's GradVarBase has no GradOpNode, so there is no next
* OpBase to execute for it and OpBasePreHook cannot be used. The leaf
* GradVarBase is handled by GradAccumulatorPostHook instead.
*/
class OpBasePreHook {
public:
virtual ~OpBasePreHook() = default;
virtual VariableWrapperList operator()(
const VariableWrapperList& grad_inputs) = 0;
};
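For illustration, a minimal developer-defined pre-hook (a hypothetical sketch, not part of this commit; CppOpBasePreHook itself is listed below as "Implement later") could look like the following, assuming VariableWrapperList is the std::vector<std::shared_ptr<VariableWrapper>> alias from type_defs.h:
class IdentityOpBasePreHook : public OpBasePreHook {
 public:
  VariableWrapperList operator()(
      const VariableWrapperList& grad_inputs) override {
    // A pre-hook receives the grad OpBase's input gradients; it must not
    // modify the input gradient VarBase, so this sketch simply forwards them.
    return grad_inputs;
  }
};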
/**
* @brief GradAccumulatorPostHook is the hook that operates on the current
* gradient after the GradientAccumulator has accumulated the gradient.
* A leaf GradVarBase has no next OpBase; if we want to register a hook
* for it, we also need to wait until the leaf GradVarBase accumulation
* is completed, so we add the post hook to the GradientAccumulator.
*
* @note GradAccumulatorPostHook will change the grad VarBase value.
*
* @note Only a leaf VarBase is allowed to hold a GradAccumulatorPostHook.
*/
class GradAccumulatorPostHook {
public:
virtual ~GradAccumulatorPostHook() = default;
virtual void operator()(VariableWrapper* var) = 0;
};
/** [ Hook for cpp functions ]
*
* Here we design three C++ hooks:
* 1. CppOpBasePreHook (Implement later):
* - used for developer-defined C++ interior VarBase hooks
* 2. CppGradAccumulatorPostHook (Implement later):
* - used for developer-defined C++ leaf VarBase hooks
* 3. LambdaGradAccumulatorPostHook:
* - used for VarBase reduce in parallel training
*
* @note [ Why are two types of GradAccumulatorPostHook needed? ]
*
* There are two types of gradient accumulation:
* 1. Gradient accumulation in same batch
* 2. Gradient accumulation across batches
* The order of execution between Hooks and gradient accumulation:
*
* [ Gradient accumulation in same batch]
* |
* [ leaf GradVarBase hooks ]
* |
* [ Gradient accumulation across batches ]
* |
* [ Gradient reduce / allreduce]
*
* Because we currently intend to handle both kinds of accumulation in one
* GradientAccumulator, we must distinguish between the two types of hooks.
*
* LambdaGradAccumulatorPostHook cannot be registered by users directly; it
* is currently only used to support the reduce strategy of parallel
* multi-card training.
*/
class LambdaGradAccumulatorPostHook : public GradAccumulatorPostHook {
public:
explicit LambdaGradAccumulatorPostHook(
std::function<void(VariableWrapper*)> fn)
: fn_(std::move(fn)) {}
void operator()(VariableWrapper* var) override { fn_(var); }
private:
std::function<void(VariableWrapper*)> fn_;
};
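As a usage sketch (it mirrors the registration done in test_hooks.cc further below; the lambda body is a placeholder for the real reduce/allreduce of parallel training):
auto reduce_hook = std::unique_ptr<LambdaGradAccumulatorPostHook>(
    new LambdaGradAccumulatorPostHook([](VariableWrapper* grad) {
      // placeholder: reduce / allreduce the accumulated gradient held by `grad`
    }));
// registered on a leaf VarBase's wrapper, e.g.:
// x->SharedVar()->AddGradVarLeafBackwardHook(std::move(reduce_hook));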
/* Hooks for python function: in pybind/imperative.cc */
/** Add Python Hooks later:
* - PyOpBasePreHook (Implement later): used for user-defined interior python
* VarBase hooks
* - PyGradAccumulatorPostHook (Implement later): used for user-defined leaf
* python VarBase hooks
*/
/** [ Hook Pipeline classes ]
*
* @note [ Why are hook pipeline classes needed? ]
*
* There are 2 purposes for adding Hook pipeline here:
*
* 1. Make the code implementation cleaner.
*
* If there were no hook pipelines, we would need to add 3 hook vectors to
* VariableWrapper, 1 hook vector to OpBase, and 2 hook vectors to
* GradientAccumulator, like:
*
* - VariableWrapper:
* std::vector<std::shared_ptr<OpBasePreHook>>
* interior_var_hooks_;
* std::vector<std::shared_ptr<GradAccumulatorPostHook>>
* leaf_var_hooks_;
* std::vector<std::shared_ptr<GradAccumulatorPostHook>>
* backward_hooks_;
*
* - OpBase:
* std::vector<std::weak_ptr<OpBasePreHook>>
* interior_var_hooks_;
*
* - GradientAccumulator:
* std::vector<std::weak_ptr<GradAccumulatorPostHook>>
* leaf_var_hooks_;
* std::vector<std::weak_ptr<GradAccumulatorPostHook>>
* backward_hooks_;
*
* This seems more complicated, and std::vector<std::weak_ptr<...>>
* is not easy to destruct.
*
* 2. Make the code easier to understand.
*
* From these two pipeline classes, we can clearly see that there are two
* types of hooks, one for the interior gradient vars and one for the leaf
* gradient vars inside the backward calculation graph.
*/
class InteriorVarHookPipeline {
public:
InteriorVarHookPipeline() = default;
void add_hook(std::unique_ptr<OpBasePreHook>&& hook) {
hooks_.emplace_back(std::move(hook));
}
const std::vector<std::unique_ptr<OpBasePreHook>>& hooks() const {
return hooks_;
}
std::vector<std::unique_ptr<OpBasePreHook>>& hooks() { return hooks_; }
private:
std::vector<std::unique_ptr<OpBasePreHook>> hooks_;
DISABLE_COPY_AND_ASSIGN(InteriorVarHookPipeline);
};
class LeafVarHookPipeline {
public:
LeafVarHookPipeline() = default;
void add_hook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
hooks_.emplace_back(std::move(hook));
}
const std::vector<std::unique_ptr<GradAccumulatorPostHook>>& hooks() const {
return hooks_;
}
std::vector<std::unique_ptr<GradAccumulatorPostHook>>& hooks() {
return hooks_;
}
void add_backward_hook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
backward_hooks_.emplace_back(std::move(hook));
}
const std::vector<std::unique_ptr<GradAccumulatorPostHook>>& backward_hooks()
const {
return backward_hooks_;
}
std::vector<std::unique_ptr<GradAccumulatorPostHook>>& backward_hooks() {
return backward_hooks_;
}
private:
std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_;
// NOTE: the `backward` here means the `whole backward process`,
// the `backward_hooks_` need to be executed after the `whole backward
// process`.
std::vector<std::unique_ptr<GradAccumulatorPostHook>> backward_hooks_;
DISABLE_COPY_AND_ASSIGN(LeafVarHookPipeline);
};
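To make the split between the two lists concrete, here is a sketch of how they would be populated (hook bodies are placeholders; in the actual code the pipeline is created lazily by VariableWrapper::GetGradVarLeafHooksSafely() below):
auto pipeline = std::make_shared<LeafVarHookPipeline>();
// hooks_: run on the leaf grad var between in-batch and cross-batch accumulation
pipeline->add_hook(std::unique_ptr<GradAccumulatorPostHook>(
    new LambdaGradAccumulatorPostHook([](VariableWrapper* grad) { /* ... */ })));
// backward_hooks_: run after the whole backward process, e.g. the reduce hook
pipeline->add_backward_hook(std::unique_ptr<GradAccumulatorPostHook>(
    new LambdaGradAccumulatorPostHook([](VariableWrapper* grad) { /* ... */ })));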
} // namespace imperative
} // namespace paddle
@@ -176,7 +176,7 @@ class OpBase {
platform::Place place_;
size_t id_{-1UL};
std::vector<std::function<void()>> backward_hooks_;
std::weak_ptr<InteriorVarHookPipeline> pre_hooks_;
};
class GradOpNode {
@@ -11,3 +11,4 @@ cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy se
cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/memory/memcpy.h"
namespace platform = paddle::platform;
namespace framework = paddle::framework;
namespace memory = paddle::memory;
DECLARE_bool(sort_sum_gradient);
namespace paddle {
namespace imperative {
using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using var_pair = std::pair<std::string, vb_vector>;
TEST(TestHooks, TestGradVarLeafBackwardHook) {
// 1. prepare
Tracer tracer;
std::shared_ptr<VarBase> x(new VarBase(true, "x"));
std::shared_ptr<VarBase> y(new VarBase(true, "y"));
std::shared_ptr<VarBase> out(new VarBase(true, "out"));
x->SetOverridedStopGradient(false);
y->SetOverridedStopGradient(false);
platform::CPUPlace place;
std::vector<float> src_data(10, 2.0);
std::vector<int64_t> x_dims = {2, 5};
std::vector<int64_t> y_dims = {5, 2};
auto* x_tensor = x->MutableVar()->GetMutable<framework::LoDTensor>();
auto* y_tensor = y->MutableVar()->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(x_dims));
auto* mutable_x = x_tensor->mutable_data<float>(place);
memory::Copy(place, mutable_x, place, src_data.data(),
sizeof(float) * src_data.size());
y_tensor->Resize(framework::make_ddim(y_dims));
auto* mutable_y = y_tensor->mutable_data<float>(place);
memory::Copy(place, mutable_y, place, src_data.data(),
sizeof(float) * src_data.size());
var_pair x_pair = var_pair("X", vb_vector(1, x));
var_pair y_pair = var_pair("Y", vb_vector(1, y));
var_pair out_pair = var_pair("Out", vb_vector(1, out));
NameVarBaseMap ins = {x_pair, y_pair};
NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
// add GradAccumulatorPostHook
auto x_var_wrapper = x->SharedVar();
x_var_wrapper->AddGradVarLeafBackwardHook(
std::unique_ptr<LambdaGradAccumulatorPostHook>(
new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) {
auto* grad_tensor =
grad->MutableVar()->GetMutable<framework::LoDTensor>();
for (int i = 0; i < grad_tensor->numel(); ++i) {
grad_tensor->mutable_data<float>(place)[i] *= 2.0;
}
})));
// 2. forward
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(out->GradVarBase()->GradOpNum(), 1UL);
// 3. backward
BasicEngine engine;
engine.Init(out.get());
engine.Execute();
framework::LoDTensor x_grad;
framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
&x_grad);
for (int i = 0; i < x_grad.numel(); ++i) {
ASSERT_EQ(x_grad.data<float>()[i], 8.0);
}
framework::LoDTensor y_grad;
framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
&y_grad);
for (int i = 0; i < y_grad.numel(); ++i) {
ASSERT_EQ(y_grad.data<float>()[i], 4.0);
}
}
void GradVarLeafBackwardHookWithGradAccmulatedTest() {
// 1. prepare
Tracer tracer;
std::shared_ptr<VarBase> x(new VarBase(true, "x"));
std::shared_ptr<VarBase> y(new VarBase(true, "y"));
std::shared_ptr<VarBase> z(new VarBase(true, "z"));
std::shared_ptr<VarBase> out_xy(new VarBase(true, "out_xy"));
std::shared_ptr<VarBase> out_xz(new VarBase(true, "out_xz"));
std::shared_ptr<VarBase> out(new VarBase(true, "out"));
x->SetOverridedStopGradient(false);
y->SetOverridedStopGradient(false);
z->SetOverridedStopGradient(false);
platform::CPUPlace place;
std::vector<float> src_data(10, 2.0);
std::vector<int64_t> x_dims = {2, 5};
std::vector<int64_t> y_dims = {5, 2};
std::vector<int64_t> z_dims = {5, 2};
auto* x_tensor = x->MutableVar()->GetMutable<framework::LoDTensor>();
auto* y_tensor = y->MutableVar()->GetMutable<framework::LoDTensor>();
auto* z_tensor = z->MutableVar()->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(x_dims));
auto* mutable_x = x_tensor->mutable_data<float>(place);
memory::Copy(place, mutable_x, place, src_data.data(),
sizeof(float) * src_data.size());
y_tensor->Resize(framework::make_ddim(y_dims));
auto* mutable_y = y_tensor->mutable_data<float>(place);
memory::Copy(place, mutable_y, place, src_data.data(),
sizeof(float) * src_data.size());
z_tensor->Resize(framework::make_ddim(z_dims));
auto* mutable_z = z_tensor->mutable_data<float>(place);
memory::Copy(place, mutable_z, place, src_data.data(),
sizeof(float) * src_data.size());
// add GradAccumulatorPostHook
auto x_var_wrapper = x->SharedVar();
x_var_wrapper->AddGradVarLeafBackwardHook(
std::unique_ptr<LambdaGradAccumulatorPostHook>(
new LambdaGradAccumulatorPostHook([=](VariableWrapper* grad) {
auto* grad_tensor =
grad->MutableVar()->GetMutable<framework::LoDTensor>();
for (int i = 0; i < grad_tensor->numel(); ++i) {
grad_tensor->mutable_data<float>(place)[i] *= 2.0;
}
})));
// 2. forward
var_pair x_pair = var_pair("X", vb_vector(1, x));
var_pair y_pair = var_pair("Y", vb_vector(1, y));
var_pair out_xy_pair = var_pair("Out", vb_vector(1, out_xy));
NameVarBaseMap ins = {x_pair, y_pair};
NameVarBaseMap outs = {out_xy_pair};
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
var_pair z_pair = var_pair("Y", vb_vector(1, z));
var_pair out_xz_pair = var_pair("Out", vb_vector(1, out_xz));
ins = {x_pair, z_pair};
outs = {out_xz_pair};
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
var_pair xy_pair = var_pair("X", vb_vector(1, out_xy));
var_pair xz_pair = var_pair("Y", vb_vector(1, out_xz));
var_pair out_pair = var_pair("Out", vb_vector(1, out));
ins = {xy_pair, xz_pair};
outs = {out_pair};
framework::AttributeMap add_attr_map;
tracer.TraceOp("elementwise_add", ins, outs, add_attr_map, place, true);
ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(out->GradVarBase()->GradOpNum(), 1UL);
// 3. backward
BasicEngine engine;
engine.Init(out.get());
engine.Execute();
framework::LoDTensor x_grad;
framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
&x_grad);
for (int i = 0; i < x_grad.numel(); ++i) {
ASSERT_EQ(x_grad.data<float>()[i], 16.0);
}
framework::LoDTensor y_grad;
framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
&y_grad);
for (int i = 0; i < y_grad.numel(); ++i) {
ASSERT_EQ(y_grad.data<float>()[i], 4.0);
}
framework::LoDTensor z_grad;
framework::TensorCopySync(z->GradVar().Get<framework::LoDTensor>(), place,
&z_grad);
for (int i = 0; i < z_grad.numel(); ++i) {
ASSERT_EQ(z_grad.data<float>()[i], 4.0);
}
}
TEST(TestHooks, TestGradVarLeafBackwardHookWithGradAccmulated) {
GradVarLeafBackwardHookWithGradAccmulatedTest();
}
TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) {
FLAGS_sort_sum_gradient = true;
GradVarLeafBackwardHookWithGradAccmulatedTest();
FLAGS_sort_sum_gradient = false;
}
} // namespace imperative
} // namespace paddle
USE_OP(mul);
USE_OP(mul_grad);
USE_OP(elementwise_add);
USE_OP(elementwise_add_grad);
@@ -16,11 +16,16 @@
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/hooks.h"
namespace paddle {
namespace imperative {
class InteriorVarHookPipeline;
class LeafVarHookPipeline;
class VarBase;
class GradOpNode;
@@ -133,6 +138,42 @@ class VariableWrapper {
}
}
/* Hook related methods: can only be called by GradVarBase */
bool HasInteriorHooks() const { return interior_hooks_ != nullptr; }
bool HasLeafHooks() const { return leaf_hooks_ != nullptr; }
void AddGradVarInteriorHook(std::unique_ptr<OpBasePreHook>&& hook) {
auto interior_hooks = GetGradVarInteriorHooksSafely();
interior_hooks->add_hook(std::move(hook));
}
void AddGradVarLeafHook(std::unique_ptr<GradAccumulatorPostHook>&& hook) {
auto leaf_hooks = GetGradVarLeafHooksSafely();
leaf_hooks->add_hook(std::move(hook));
}
void AddGradVarLeafBackwardHook(
std::unique_ptr<GradAccumulatorPostHook>&& hook) {
auto leaf_hooks = GetGradVarLeafHooksSafely();
leaf_hooks->add_backward_hook(std::move(hook));
}
const std::shared_ptr<InteriorVarHookPipeline>& GetInteriorHooks() const {
return interior_hooks_;
}
std::shared_ptr<InteriorVarHookPipeline>& GetInteriorHooks() {
return interior_hooks_;
}
const std::shared_ptr<LeafVarHookPipeline>& GetLeafHooks() const {
return leaf_hooks_;
}
std::shared_ptr<LeafVarHookPipeline>& GetLeafHooks() { return leaf_hooks_; }
private:
void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
auto shared_var = grad_var_.lock();
@@ -159,6 +200,41 @@ class VariableWrapper {
}
}
/* Hook related private methods */
std::shared_ptr<VariableWrapper> GetGradVarSafely() const {
auto shared_grad_var = grad_var_.lock();
PADDLE_ENFORCE_NOT_NULL(
shared_grad_var,
platform::errors::PermissionDenied(
"Cannot add gradient hook on Tensor without gradient."));
return shared_grad_var;
}
std::shared_ptr<InteriorVarHookPipeline>& GetGradVarInteriorHooksSafely() {
auto shared_grad_var = GetGradVarSafely();
PADDLE_ENFORCE_EQ(HasGradNode(), true,
platform::errors::PermissionDenied(
"Only interior Tensor in backward can register "
"interior gradient hook."));
if (shared_grad_var->interior_hooks_ == nullptr) {
shared_grad_var->interior_hooks_ =
std::make_shared<InteriorVarHookPipeline>();
}
return shared_grad_var->interior_hooks_;
}
std::shared_ptr<LeafVarHookPipeline>& GetGradVarLeafHooksSafely() {
auto shared_grad_var = GetGradVarSafely();
PADDLE_ENFORCE_EQ(
HasGradNode(), false,
platform::errors::PermissionDenied(
"Only leaf Tensor in backward can register leaf gradient hook."));
if (shared_grad_var->leaf_hooks_ == nullptr) {
shared_grad_var->leaf_hooks_ = std::make_shared<LeafVarHookPipeline>();
}
return shared_grad_var->leaf_hooks_;
}
private:
framework::Variable var_;
std::string name_;
@@ -173,6 +249,12 @@ class VariableWrapper {
std::weak_ptr<VariableWrapper> grad_var_;
std::weak_ptr<GradOpNode> grad_node_;
// NOTE: only grad var can hold hooks now
// only interior var can hold interior hooks
std::shared_ptr<InteriorVarHookPipeline> interior_hooks_;
// only leaf var can hold leaf hooks
std::shared_ptr<LeafVarHookPipeline> leaf_hooks_;
};
} // namespace imperative