// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace imperative {

class VariableWrapper;

/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ]
 *
 * @brief This hook functor is executed before the grad OpBase is executed,
 *        taking the input of the current grad OpBase as input, and
 *        executing python hooks (user-defined) or C++ hooks (developer-defined)
 *        to achieve the purpose of custom operations on the interior VarBase
 *        gradient.
 *
 * @note  This hook functor will not change the input gradient VarBase.
 *
 * @note  [Why need to be OpBase `PreHook`, why not `PostHook`?]
 *
 *        1. If we set an OpBase post hook, then when the op finishes
 *        executing, the op's output gradient may not be in its final state,
 *        because other ops' gradient outputs may still need to be accumulated
 *        into it. But before an op can be executed, its input gradients must
 *        already have been accumulated to their final values.
 *        2. We don't want the hook to change its input Tensor value, so now
 *        we can't call all hooks in GradAccumulator.
 *
 * @note  [Why only can be used for interior VarBase?]
 *
 *        Because the leaf VarBase's GradVarBase has no GradOpNode, leaf
 *        GradVarBase has no next OpBase to execute, so if we need to deal
 *        with the leaf GradVarBase, we cannot use this hook functor. That
 *        case is handled by a separate inplace hook method.
 */
// Forward declaration (also declared near the top of this header; a repeated
// forward declaration is legal and keeps this section self-contained).
class VariableWrapper;

class VariableWrapperHook {
 public:
  virtual ~VariableWrapperHook() = default;

  // Called with an input gradient VariableWrapper of the grad OpBase before
  // the op runs; returns the VariableWrapper to use in its place.
  //
  // @param var  input gradient var of the grad OpBase (read-only; per the
  //             contract above, the hook must not modify `var` in place)
  // @return     the (possibly replaced) gradient var to feed to the op
  virtual std::shared_ptr<VariableWrapper> operator()(
      const std::shared_ptr<VariableWrapper>& var) = 0;
};

/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ]
 *
 * @brief This hook functor is the Hook that operates on the current
 *        gradient after the GradientAccumulator has accumulated the gradient.
 *        Leaf GradVarBase has no next OpBase; if we want to register a hook
 *        for it, we also need to wait until the leaf GradVarBase accumulation
 *        is completed, so we can add a post hook to GradientAccumulator.
 *
 * @note  This hook functor will change the grad VarBase value.
 *
 * @note  Only leaf VarBase holders are allowed to call this hook functor.
 */
// Forward declaration (also declared near the top of this header; a repeated
// forward declaration is legal and keeps this section self-contained).
class VariableWrapper;

class InplaceVariableWrapperHook {
 public:
  virtual ~InplaceVariableWrapperHook() = default;

  // Called after the GradientAccumulator has produced the final accumulated
  // gradient for a leaf var; implementations mutate `var` in place.
  virtual void operator()(VariableWrapper* var) = 0;
};

// Adapter that wraps an arbitrary callable (typically a lambda) as an
// InplaceVariableWrapperHook, so callers can register hooks without writing
// a dedicated subclass.
class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook {
 public:
  // Takes ownership of `fn`, the callable invoked for each hooked var.
  explicit LambdaInplaceVariableWrapperHook(
      std::function<void(VariableWrapper*)>&& fn)
      : fn_(std::move(fn)) {}

  // Forwards the accumulated gradient var to the stored callable.
  void operator()(VariableWrapper* var) override { fn_(var); }

 private:
  std::function<void(VariableWrapper*)> fn_;  // user-supplied hook body
};

}  // namespace imperative
}  // namespace paddle