/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <functional>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/attribute_checker.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace framework {

// Forward declarations. The code below only refers to these types through
// pointers (InferShapeBase::operator() and OpInfo::checker_), so their full
// definitions are not needed here.
class OpAttrChecker;
class InferShapeContext;

// Polymorphic interface for shape-inference callables. Derived classes
// implement the const call operator, which receives an InferShapeContext*.
// The virtual destructor allows deletion through a base pointer.
class InferShapeBase {
 public:
  virtual ~InferShapeBase() = default;

  // Invoked with the context to operate on. The operator is const, so an
  // implementation cannot modify its own (non-mutable) state here.
  virtual void operator()(InferShapeContext* ctx) const = 0;
};

class OpInfo {
 public:
Y
Yu Yang 已提交
44
  OpCreator creator_;
Y
Yu Yang 已提交
45
  GradOpMakerFN grad_op_maker_;
46
  proto::OpProto* proto_{nullptr};
47
  OpAttrChecker* checker_{nullptr};
Y
Yu Yang 已提交
48
  InferVarTypeFN infer_var_type_;
49
  InferShapeFN infer_shape_;
D
dzhwinter 已提交
50
  InferInplaceOpFN infer_inplace_;
S
sneaxiy 已提交
51
  InferNoNeedBufferVarsFN infer_no_need_buffer_vars_;
H
hong 已提交
52
  DygraphGradOpMakerFN dygraph_grad_op_maker_;
Y
Yu Yang 已提交
53

S
sneaxiy 已提交
54 55 56 57
  // NOTE(zjl): this flag is added to check whether
  // the grad maker is the default one.
  bool use_default_grad_op_desc_maker_{false};

58 59 60 61
  // NOTE(huihuangzheng): this flag is added to check whether
  // the grad maker is the empty one.
  bool use_empty_grad_op_desc_maker_{false};

Y
Yu Yang 已提交
62 63 64 65
  bool HasOpProtoAndChecker() const {
    return proto_ != nullptr && checker_ != nullptr;
  }

66
  const proto::OpProto& Proto() const {
67 68 69
    PADDLE_ENFORCE_NOT_NULL(
        proto_,
        platform::errors::NotFound("Operator's Proto has not been registered"));
70 71
    PADDLE_ENFORCE_EQ(proto_->IsInitialized(),
                      true,
72 73
                      platform::errors::InvalidArgument(
                          "Operator's Proto in op info is not initialized."));
Y
Yu Yang 已提交
74 75 76 77
    return *proto_;
  }

  const OpCreator& Creator() const {
78
    PADDLE_ENFORCE_NOT_NULL(creator_,
79 80
                            platform::errors::NotFound(
                                "Operator's Creator has not been registered."));
Y
Yu Yang 已提交
81 82
    return creator_;
  }
Y
Yu Yang 已提交
83 84

  const GradOpMakerFN& GradOpMaker() const {
85 86 87 88 89
    // Normally, proto_ should not be null, except some special operators, such
    // as LeaklyReluDoubleGrad op.
    std::string type = proto_ ? proto_->type() : "unknown";
    PADDLE_ENFORCE_NOT_NULL(
        grad_op_maker_,
90 91 92 93 94
        platform::errors::NotFound(
            "Operator %s's GradOpMaker has not been "
            "registered.\nPlease check whether (%s) operator has "
            "gradient operator.\nIf not, please set stop_gradient to be True "
            "for its input and output variables using var.stop_gradient=True.",
95 96
            type.c_str(),
            type.c_str()));
Y
Yu Yang 已提交
97 98
    return grad_op_maker_;
  }
F
fengjiayi 已提交
99

100
  // some ops don't have grad_op_maker, add check before use GradOpMaker()
101
  bool HasGradOpMaker() const { return grad_op_maker_ != nullptr; }
102

103 104 105 106
  bool HasNonEmptyGradOpMaker() const {
    return grad_op_maker_ != nullptr && !use_empty_grad_op_desc_maker_;
  }

H
hong 已提交
107 108 109 110 111 112
  const DygraphGradOpMakerFN& DygraphGradOpMaker() const {
    // Normally, proto_ should not be null, except some special operators, such
    // as LeaklyReluDoubleGrad op.
    std::string type = proto_ ? proto_->type() : "unknown";
    PADDLE_ENFORCE_NOT_NULL(
        dygraph_grad_op_maker_,
113 114 115 116 117
        platform::errors::NotFound(
            "Operator %s's DygraphGradOpMaker has not been "
            "registered.\nPlease check whether (%s) operator has "
            "gradient operator.\nIf not, please set stop_gradient to be True "
            "for its input and output variables using var.stop_gradient=True.",
118 119
            type.c_str(),
            type.c_str()));
H
hong 已提交
120 121 122 123
    return dygraph_grad_op_maker_;
  }

  bool HasDygraphGradOpMaker() const {
124
    return dygraph_grad_op_maker_ != nullptr;
H
hong 已提交
125 126
  }

127
  bool HasInferInplace() const { return infer_inplace_ != nullptr; }
128

F
fengjiayi 已提交
129
  const OpAttrChecker* Checker() const { return checker_; }
S
sneaxiy 已提交
130 131 132 133

  const InferNoNeedBufferVarsFN& NoNeedBufferVarsInferer() const {
    return infer_no_need_buffer_vars_;
  }
Y
Yu Yang 已提交
134 135
};

Y
Yu Yang 已提交
136 137 138 139 140 141 142 143
class OpInfoMap {
 public:
  static OpInfoMap& Instance();

  bool Has(const std::string& op_type) const {
    return map_.find(op_type) != map_.end();
  }

144
  void Insert(const std::string& type, const OpInfo& info) {
145 146
    PADDLE_ENFORCE_NE(Has(type),
                      true,
147 148
                      platform::errors::AlreadyExists(
                          "Operator (%s) has been registered.", type));
Y
Yu Yang 已提交
149 150 151 152
    map_.insert({type, info});
  }

  const OpInfo& Get(const std::string& type) const {
153
    auto op_info_ptr = GetNullable(type);
154 155 156
    PADDLE_ENFORCE_NOT_NULL(
        op_info_ptr,
        platform::errors::NotFound("Operator (%s) is not registered.", type));
157 158 159 160
    return *op_info_ptr;
  }

  const OpInfo* GetNullable(const std::string& type) const {
Y
Yu Yang 已提交
161
    auto it = map_.find(type);
162 163 164 165 166
    if (it == map_.end()) {
      return nullptr;
    } else {
      return &it->second;
    }
Y
Yu Yang 已提交
167 168
  }

C
chentianyu03 已提交
169
  const paddle::flat_hash_map<std::string, OpInfo>& map() const { return map_; }
170

C
chentianyu03 已提交
171
  paddle::flat_hash_map<std::string, OpInfo>* mutable_map() { return &map_; }
Y
Yu Yang 已提交
172

S
sneaxiy 已提交
173 174
  std::vector<std::string> GetUseDefaultGradOpDescMakerOps() const;

Y
Yu Yang 已提交
175 176
 private:
  OpInfoMap() = default;
C
chentianyu03 已提交
177
  paddle::flat_hash_map<std::string, OpInfo> map_;
D
format  
dongzhihong 已提交
178 179

  DISABLE_COPY_AND_ASSIGN(OpInfoMap);
Y
Yu Yang 已提交
180
};
Y
Yu Yang 已提交
181 182 183

}  // namespace framework
}  // namespace paddle