// File: no_need_buffer_vars_inference.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

class InferNoNeedBufferVarsContext {
S
sneaxiy 已提交
29
 public:
30 31 32
  explicit InferNoNeedBufferVarsContext(const framework::AttributeMap &attrs)
      : attrs_(attrs) {}
  virtual ~InferNoNeedBufferVarsContext() = default;
S
sneaxiy 已提交
33

34
  virtual bool HasOutput(const std::string &slot) const = 0;
S
sneaxiy 已提交
35

36
  const Attribute &GetAttr(const std::string &attr) const;
S
sneaxiy 已提交
37

38 39 40
 private:
  const framework::AttributeMap &attrs_;
};
class StaticGraphInferNoNeedBufferVarsContext final
    : public InferNoNeedBufferVarsContext {
 public:
  StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
                                          const VariableNameMap &outputs,
                                          const AttributeMap &attrs);
S
sneaxiy 已提交
48

49
  bool HasOutput(const std::string &slot) const final;
S
sneaxiy 已提交
50 51 52 53 54 55

 private:
  const VariableNameMap &inputs_;
  const VariableNameMap &outputs_;
};

// InferNoNeedBufferVarsContext backed by imperative::NameVarBaseMap, as used
// by the imperative (dygraph) overload of InferNoNeedBufferVarsFN::operator().
class DyGraphInferNoNeedBufferVarsContext final
    : public InferNoNeedBufferVarsContext {
 public:
  // Defined out of line. The members below are references, so `inputs`,
  // `outputs` and `attrs` are presumably bound rather than copied and must
  // outlive this context — confirm against the definition.
  //
  // The last parameter is spelled `attrs` to match the base class and the
  // static-graph sibling constructor.
  DyGraphInferNoNeedBufferVarsContext(const imperative::NameVarBaseMap &inputs,
                                      const imperative::NameVarBaseMap &outputs,
                                      const AttributeMap &attrs);

  // Queries output slot `slot`; see the out-of-line definition for the exact
  // criterion.
  bool HasOutput(const std::string &slot) const final;

 private:
  const imperative::NameVarBaseMap &inputs_;   // Not owned.
  const imperative::NameVarBaseMap &outputs_;  // Not owned.
};

// Interface for functors that, given an InferNoNeedBufferVarsContext, return
// a set of names as a reference to a std::unordered_set<std::string>.
// Judging by the class name these are the operator's "no need buffer"
// variables; the precise meaning of the names is defined by the callers.
class NoNeedBufferVarsInference {
 public:
  virtual ~NoNeedBufferVarsInference() = default;

  // Computes the result for `ctx`. The returned reference must stay valid
  // after the call returns, so implementations hand back storage that
  // outlives the call (for example a function-local static).
  virtual const std::unordered_set<std::string> &operator()(
      const InferNoNeedBufferVarsContext &ctx) const = 0;

 protected:
  // Shared empty set for implementations that have nothing to report.
  static const std::unordered_set<std::string> &Empty() {
    static const std::unordered_set<std::string> kNoVars{};
    return kNoVars;
  }
};

// Declares a final class `class_type` deriving from
// ::paddle::framework::NoNeedBufferVarsInference whose operator() ignores the
// context and always returns the same fixed set of names.
//
//   class_type  - name of the class to declare.
//   ...         - initializer-list elements for the returned
//                 std::unordered_set<std::string> (may be empty).
//
// The set is a function-local static, so it is built once on first use and
// the returned reference stays valid afterwards. The local is deliberately
// not spelled with a double underscore: such identifiers are reserved for the
// implementation in C++.
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...)               \
  class class_type final                                                     \
      : public ::paddle::framework::NoNeedBufferVarsInference {              \
   public:                                                                   \
    using ::paddle::framework::NoNeedBufferVarsInference::                   \
        NoNeedBufferVarsInference;                                           \
                                                                             \
    const std::unordered_set<std::string> &operator()(                       \
        const ::paddle::framework::InferNoNeedBufferVarsContext & /*ctx*/)   \
        const final {                                                        \
      static const std::unordered_set<std::string> kNoNeedBufferVars{        \
          __VA_ARGS__};                                                      \
      return kNoNeedBufferVars;                                              \
    }                                                                        \
  }

class InferNoNeedBufferVarsFN {
 public:
  inline const std::unordered_set<std::string> &operator()(
      const VariableNameMap &inputs, const VariableNameMap &outputs,
      const AttributeMap &attrs) const {
    PADDLE_ENFORCE_NOT_NULL(inferer_);
    StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
    return (*inferer_)(ctx);
  }

  inline const std::unordered_set<std::string> &operator()(
      const imperative::NameVarBaseMap &inputs,
      const imperative::NameVarBaseMap &outputs,
      const AttributeMap &attrs) const {
    PADDLE_ENFORCE_NOT_NULL(inferer_);
    DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
    return (*inferer_)(ctx);
S
sneaxiy 已提交
115 116
  }

117
  inline explicit operator bool() const { return inferer_ != nullptr; }
118 119 120 121 122 123 124 125 126 127 128 129 130

  inline bool operator!() const { return inferer_ == nullptr; }

  inline void Reset(const std::shared_ptr<NoNeedBufferVarsInference> &inferer) {
    PADDLE_ENFORCE_NOT_NULL(inferer);
    PADDLE_ENFORCE_EQ(inferer_, nullptr);
    inferer_ = inferer;
  }

 private:
  std::shared_ptr<NoNeedBufferVarsInference> inferer_;
};

}  // namespace framework
}  // namespace paddle