// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace imperative {

class VariableWrapper;

26
/** [ VariableWrapper Hook ]
27
 *
28 29 30 31 32 33
 * @brief This hook functor is executed before the grad OpBase is executed or
 *        after gradient accumulation completed in current batch.
 *        1. For interior var, VariableWrapper Hook take the input of the
 *        current grad OpBase as input.
 *        2. For leaf var, VariableWrapper Hook take the inner_var_ of
 *        GradientAccumulator as input.
34
 *
35 36 37 38
 * @note  This hook functor will not change the input gradient VariableWrapper,
 *        but if you copy the input VariableWrapper and change the value of
 *        Variable in VariableWrapper, the value of input will also be changed,
 *        because they shared same PlaceHolder.
39
 *
40
 * @note  [ Why need to be OpBase `PreHook`, why not `PostHook`? ]
41
 *
42
 *        We expect If set OpBase post hook, when the op executed end, the
43 44 45 46
 *        op's output gradient may not be the final state, because it may need
 *        other op's gradient output to accumulated to it. But before op can
 *        be executed, the gradient output must have been accumulated to final
 *        value.
47
 *
48
 * @note  [ Why Leaf gradient is special? ]
49 50 51
 *
 *        Because the leaf VarBase's GradVarBase has no GradOpNode, so leaf
 *        GradVarBase has no next OpBase to executed, so if need to deal with
52 53
 *        the leaf GradVarBase, we should call hooks after gradient accumulation
 *        completed.
54
 */
55
class VariableWrapperHook {
56
 public:
57 58 59
  virtual ~VariableWrapperHook() = default;
  virtual std::shared_ptr<VariableWrapper> operator()(
      const std::shared_ptr<VariableWrapper>& var) = 0;
60 61
};

62
class CppVariableWrapperHook : public VariableWrapperHook {
63
 public:
64 65 66
  explicit CppVariableWrapperHook(
      std::function<std::shared_ptr<VariableWrapper>(
          const std::shared_ptr<VariableWrapper>&)>&& fn)
67 68
      : fn_(std::move(fn)) {}

69 70 71 72
  std::shared_ptr<VariableWrapper> operator()(
      const std::shared_ptr<VariableWrapper>& var) override {
    return fn_(var);
  }
73 74

 private:
75 76 77
  std::function<std::shared_ptr<VariableWrapper>(
      const std::shared_ptr<VariableWrapper>&)>
      fn_;
78 79 80 81
};

}  // namespace imperative
}  // namespace paddle