no_need_buffer_vars_inference.h 5.7 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#include <memory>
S
sneaxiy 已提交
18 19 20
#include <string>
#include <unordered_set>
#include <vector>
21 22 23
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
S
sneaxiy 已提交
24 25 26 27

namespace paddle {
namespace framework {

28
class InferNoNeedBufferVarsContext {
S
sneaxiy 已提交
29
 public:
30 31 32
  explicit InferNoNeedBufferVarsContext(const framework::AttributeMap &attrs)
      : attrs_(attrs) {}
  virtual ~InferNoNeedBufferVarsContext() = default;
S
sneaxiy 已提交
33

34
  virtual bool HasOutput(const std::string &slot) const = 0;
S
sneaxiy 已提交
35

36
  const Attribute &GetAttr(const std::string &attr) const;
S
sneaxiy 已提交
37

38 39 40
 private:
  const framework::AttributeMap &attrs_;
};
S
sneaxiy 已提交
41

42 43 44 45 46 47
class StaticGraphInferNoNeedBufferVarsContext final
    : public InferNoNeedBufferVarsContext {
 public:
  StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
                                          const VariableNameMap &outputs,
                                          const AttributeMap &attrs);
S
sneaxiy 已提交
48

49
  bool HasOutput(const std::string &slot) const final;
S
sneaxiy 已提交
50 51 52 53 54 55

 private:
  const VariableNameMap &inputs_;
  const VariableNameMap &outputs_;
};

56 57 58
class DyGraphInferNoNeedBufferVarsContext final
    : public InferNoNeedBufferVarsContext {
 public:
59 60 61 62
  DyGraphInferNoNeedBufferVarsContext(
      const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
      const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
      const AttributeMap &attrs);
63 64 65 66

  bool HasOutput(const std::string &slot) const final;

 private:
67 68
  const imperative::NameVarMap<imperative::VariableWrapper> &inputs_;
  const imperative::NameVarMap<imperative::VariableWrapper> &outputs_;
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
};

/// Polymorphic interface: given an InferNoNeedBufferVarsContext, return a set
/// of variable names (presumably the variables whose data buffers the operator
/// does not need; confirm against the operator registry).
class NoNeedBufferVarsInference {
 public:
  virtual ~NoNeedBufferVarsInference() = default;

  /// Returns the selected variable names for `ctx`. The result is returned by
  /// reference and handed on to callers, so implementations must return
  /// storage that outlives the call (e.g. a static or a member).
  virtual const std::unordered_set<std::string> &operator()(
      const InferNoNeedBufferVarsContext &ctx) const = 0;

 protected:
  /// Shared, program-lifetime empty set for implementations that select
  /// no variables.
  static const std::unordered_set<std::string> &Empty() {
    static const std::unordered_set<std::string> kEmptySet{};
    return kEmptySet;
  }
};

Z
Zeng Jinle 已提交
84
// Defines a `final` NoNeedBufferVarsInference subclass named `class_type`.
// Its operator() ignores the context and always returns the fixed set built
// from the remaining macro arguments. Usage (note the trailing semicolon,
// because the expansion ends with `}`):
//   DECLARE_NO_NEED_BUFFER_VARS_INFERER(FooNoNeedBufferVarsInferer, "X", "Y");
// The set is a function-local static, so the returned reference stays valid
// after the call.
// The local is not named with a double underscore (such identifiers are
// reserved for the implementation). The context parameter is left unnamed
// because it is unused, which avoids -Wunused-parameter warnings in every
// expansion. Inside the macro body, only /* */ comments may be used.
#define DECLARE_NO_NEED_BUFFER_VARS_INFERER(class_type, ...)               \
  class class_type final                                                   \
      : public ::paddle::framework::NoNeedBufferVarsInference {            \
   public:                                                                 \
    using ::paddle::framework::NoNeedBufferVarsInference::                 \
        NoNeedBufferVarsInference;                                         \
                                                                           \
    const std::unordered_set<std::string> &operator()(                     \
        const ::paddle::framework::InferNoNeedBufferVarsContext & /*ctx*/) \
        const final {                                                      \
      /* Built once, on the first call. */                                 \
      static const std::unordered_set<std::string> no_need_buffer_vars{    \
          __VA_ARGS__};                                                    \
      return no_need_buffer_vars;                                          \
    }                                                                      \
  }

class InferNoNeedBufferVarsFN {
 public:
  inline const std::unordered_set<std::string> &operator()(
      const VariableNameMap &inputs, const VariableNameMap &outputs,
      const AttributeMap &attrs) const {
104 105 106 107
    PADDLE_ENFORCE_NOT_NULL(
        inferer_,
        platform::errors::PreconditionNotMet(
            "The `inferer_` of InferNoNeedBufferVarsFN is not initialized."));
108 109 110 111 112
    StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
    return (*inferer_)(ctx);
  }

  inline const std::unordered_set<std::string> &operator()(
113 114
      const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
      const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
115
      const AttributeMap &attrs) const {
116 117 118 119
    PADDLE_ENFORCE_NOT_NULL(
        inferer_,
        platform::errors::PreconditionNotMet(
            "The `inferer_` of InferNoNeedBufferVarsFN is not initialized."));
120 121
    DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
    return (*inferer_)(ctx);
S
sneaxiy 已提交
122 123
  }

124
  inline explicit operator bool() const { return inferer_ != nullptr; }
125 126 127 128

  inline bool operator!() const { return inferer_ == nullptr; }

  inline void Reset(const std::shared_ptr<NoNeedBufferVarsInference> &inferer) {
129 130 131 132 133 134 135 136
    PADDLE_ENFORCE_NOT_NULL(
        inferer, platform::errors::InvalidArgument("The input inferer of "
                                                   "InferNoNeedBufferVarsFN::"
                                                   "Reset is nullptr."));
    PADDLE_ENFORCE_EQ(
        inferer_, nullptr,
        platform::errors::AlreadyExists(
            "The `inferer_` of InferNoNeedBufferVarsFN has been initialized."));
137 138 139
    inferer_ = inferer;
  }

Z
Zeng Jinle 已提交
140 141 142 143
  inline bool operator==(std::nullptr_t) const { return inferer_ == nullptr; }

  inline bool operator!=(std::nullptr_t) const { return inferer_ != nullptr; }

144 145 146 147
 private:
  std::shared_ptr<NoNeedBufferVarsInference> inferer_;
};

Z
Zeng Jinle 已提交
148 149 150 151 152 153 154 155 156 157
// Mirror images of InferNoNeedBufferVarsFN's member comparisons, so that
// `nullptr == fn` and `nullptr != fn` compile as well as `fn == nullptr`.
// These are declared plain `inline` rather than `static inline`. A function
// defined in a header should have one ODR-shared definition, not a separate
// internal-linkage copy in every translation unit that includes the header.
inline bool operator==(std::nullptr_t, const InferNoNeedBufferVarsFN &other) {
  return other == nullptr;
}

inline bool operator!=(std::nullptr_t, const InferNoNeedBufferVarsFN &other) {
  return other != nullptr;
}

S
sneaxiy 已提交
158 159
}  // namespace framework
}  // namespace paddle