/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <map>
18
#include <memory>
19 20
#include <string>
#include <tuple>
S
sneaxiy 已提交
21
#include <type_traits>
M
minqiyang 已提交
22 23
#include <unordered_map>
#include <unordered_set>
24
#include <vector>
25

Y
Yi Wang 已提交
26
#include "paddle/fluid/framework/grad_op_desc_maker.h"
D
dzhwinter 已提交
27
#include "paddle/fluid/framework/inplace_op_inference.h"
S
sneaxiy 已提交
28
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
Y
Yi Wang 已提交
29 30 31 32
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type_inference.h"
H
hong 已提交
33 34
#include "paddle/fluid/imperative/dygraph_grad_maker.h"
#include "paddle/fluid/imperative/type_defs.h"
J
Jiabin Yang 已提交
35
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
36 37 38 39 40 41 42 43

namespace paddle {
namespace framework {
namespace details {

// Identifies which OpInfo slot a registered helper class fills in.
// Every value is spelled out explicitly; kUnknown (-1) is reported for a
// class that derives from none of the registered base types.
enum OpInfoFillType {
  kOperator = 0,                   // OperatorBase subclass
  kOpProtoAndCheckerMaker = 1,     // OpProtoAndCheckerMaker subclass
  kGradOpDescMaker = 2,            // GradOpDescMakerBase subclass
  kVarTypeInference = 3,           // VarTypeInference subclass
  kShapeInference = 4,             // InferShapeBase subclass
  kInplaceOpInference = 5,         // InplaceOpInference subclass
  kNoNeedBufferVarsInference = 6,  // NoNeedBufferVarsInference subclass
  kGradOpBaseMaker = 7,            // imperative::GradOpBaseMakerBase subclass
  kGradCompOpDescMaker = 8,        // prim::CompositeGradOpMakerBase subclass
  kUnknown = -1                    // no registered base type matched
};

namespace internal {
template <typename T, OpInfoFillType kType>
struct TypePair {
  using Type = T;
  static constexpr OpInfoFillType kFillType = kType;
};

using OpRegistryClasses = std::tuple<                                // NOLINT
    TypePair<OperatorBase, kOperator>,                               // NOLINT
    TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,       // NOLINT
    TypePair<GradOpDescMakerBase, kGradOpDescMaker>,                 // NOLINT
H
hong 已提交
65
    TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,     // NOLINT
66
    TypePair<prim::CompositeGradOpMakerBase, kGradCompOpDescMaker>,  // NOLINT
S
sneaxiy 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
    TypePair<VarTypeInference, kVarTypeInference>,                   // NOLINT
    TypePair<InferShapeBase, kShapeInference>,                       // NOLINT
    TypePair<InplaceOpInference, kInplaceOpInference>,               // NOLINT
    TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference>  // NOLINT
    >;

static constexpr int kOpRegistryClassNumber =
    std::tuple_size<OpRegistryClasses>::value;

template <typename T, int kPos, bool kIsBounded /* = true*/>
struct IsMatchedBaseTypeImpl {
  using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
  static constexpr bool kValue =
      std::is_base_of<typename PairType::Type, T>::value;
};

template <typename T, int kPos>
struct IsMatchedBaseTypeImpl<T, kPos, false> {
  static constexpr bool kValue = false;
};

template <typename T, int kPos>
static inline constexpr bool IsMatchedBaseType() {
90 91 92 93
  return IsMatchedBaseTypeImpl<T,
                               kPos,
                               (kPos >= 0 &&
                                kPos < kOpRegistryClassNumber)>::kValue;
S
sneaxiy 已提交
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
}

template <typename T, int kStart, int kEnd, bool kIsEnd, bool kIsMatched>
struct OpInfoFillTypeGetterImpl {};

// This case should not happen
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, true> {};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, false> {
  static constexpr OpInfoFillType kType = kUnknown;
};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
  static constexpr OpInfoFillType kType =
111 112 113 114
      OpInfoFillTypeGetterImpl<T,
                               kStart + 1,
                               kEnd,
                               kStart + 1 == kEnd,
S
sneaxiy 已提交
115 116 117 118 119 120 121 122 123 124 125
                               IsMatchedBaseType<T, kStart + 1>()>::kType;
};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
  using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
  static constexpr OpInfoFillType kType = PairType::kFillType;
};

template <typename T>
using OpInfoFillTypeGetter =
126 127 128
    OpInfoFillTypeGetterImpl<T,
                             0,
                             kOpRegistryClassNumber,
S
sneaxiy 已提交
129 130 131 132 133
                             kOpRegistryClassNumber == 0,
                             IsMatchedBaseType<T, 0>()>;

}  // namespace internal

134 135 136
// Compile-time mapping from a helper class T to the OpInfoFillType slot it
// fills. Yields kUnknown when T derives from none of the registered bases.
template <typename T>
struct OpInfoFillTypeID {
  static constexpr OpInfoFillType ID() {
    using Getter = internal::OpInfoFillTypeGetter<T>;
    return Getter::kType;
  }
};

// Primary template, declared but never defined. Each specialization, keyed
// on the fill type that OpInfoFillTypeID deduces for T, writes T's
// contribution into an OpInfo. A combination without a specialization is
// an incomplete type and fails to compile when instantiated.
template <typename T, OpInfoFillType kFillType = OpInfoFillTypeID<T>::ID()>
struct OpInfoFiller;

template <size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;

template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, false, ARGS...> {
 public:
  using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
    OpInfoFiller<T> fill;
    fill(op_type, info);
    constexpr auto size = sizeof...(ARGS);
    OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
                                                                  info);
    (void)(reg);
  }
};

template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, true, ARGS...> {
 public:
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
};

template <typename T>
struct OpInfoFiller<T, kOperator> {
  void operator()(const char* op_type, OpInfo* info) const {
170 171
    PADDLE_ENFORCE_EQ(info->creator_,
                      nullptr,
Z
Zeng Jinle 已提交
172 173
                      platform::errors::AlreadyExists(
                          "OpCreator of %s has been registered", op_type));
174 175
    info->creator_ = [](const std::string& type,
                        const VariableNameMap& inputs,
176 177 178 179
                        const VariableNameMap& outputs,
                        const AttributeMap& attrs) {
      return new T(type, inputs, outputs, attrs);
    };
Z
Zeng Jinle 已提交
180 181 182

    if (std::is_base_of<OperatorWithKernel, T>::value) {
      PADDLE_ENFORCE_EQ(
183 184
          info->infer_shape_,
          nullptr,
Z
Zeng Jinle 已提交
185 186 187
          platform::errors::AlreadyExists(
              "Duplicate InferShapeFN of %s has been registered", op_type));

188 189
      OperatorWithKernel* op = dynamic_cast<OperatorWithKernel*>(info->creator_(
          std::string{}, VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
190 191 192
      PADDLE_ENFORCE_NOT_NULL(
          op,
          platform::errors::InvalidArgument("%s should have kernels", op_type));
Z
Zeng Jinle 已提交
193 194 195 196
      info->infer_shape_ = [op](InferShapeContext* ctx) {
        op->InferShape(ctx);
      };
    }
197 198 199 200 201 202
  }
};

template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
203 204
    PADDLE_ENFORCE_EQ(info->proto_,
                      nullptr,
Z
Zeng Jinle 已提交
205
                      platform::errors::AlreadyExists(
206
                          "OpProto of %s has been registered.", op_type));
207 208
    PADDLE_ENFORCE_EQ(info->checker_,
                      nullptr,
Z
Zeng Jinle 已提交
209
                      platform::errors::AlreadyExists(
210
                          "OpAttrChecker of %s has been registered.", op_type));
211
    info->proto_ = new proto::OpProto;
212
    info->checker_ = new OpAttrChecker();
213
    info->proto_->set_type(op_type);
Y
Yu Yang 已提交
214
    T maker;
Y
yuyang18 已提交
215
    maker(info->proto_, info->checker_);
216
    PADDLE_ENFORCE_EQ(
217 218
        info->proto_->IsInitialized(),
        true,
219 220
        platform::errors::PreconditionNotMet(
            "Fail to initialize %s's OpProto, because %s is not initialized.",
221 222
            op_type,
            info->proto_->InitializationErrorString()));
223 224 225 226 227 228
  }
};

template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
229
    PADDLE_ENFORCE_EQ(
230 231
        info->grad_op_maker_,
        nullptr,
Z
Zeng Jinle 已提交
232 233 234
        platform::errors::AlreadyExists(
            "GradOpDescMaker of %s has been registered", op_type));

235 236 237 238 239 240 241 242
    info->grad_op_maker_ =
        [](const OpDesc& fwd_op,
           const std::unordered_set<std::string>& no_grad_set,
           std::unordered_map<std::string, std::string>* grad_to_var,
           const std::vector<BlockDesc*>& grad_block) {
          T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
          return maker();
        };
S
sneaxiy 已提交
243 244

    info->use_default_grad_op_desc_maker_ =
H
hong 已提交
245
        std::is_base_of<DefaultGradOpMaker<OpDesc, true>, T>::value ||
246 247 248 249 250 251 252 253 254
        std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value ||
        std::is_base_of<DefaultGradOpMaker<imperative::OpBase, true>,
                        T>::value ||
        std::is_base_of<DefaultGradOpMaker<imperative::OpBase, false>,
                        T>::value;

    info->use_empty_grad_op_desc_maker_ =
        std::is_base_of<EmptyGradOpMaker<OpDesc>, T>::value ||
        std::is_base_of<EmptyGradOpMaker<imperative::OpBase>, T>::value;
H
hong 已提交
255 256 257
  }
};

J
Jiabin Yang 已提交
258 259 260 261 262 263 264
template <typename T>
struct OpInfoFiller<T, kGradCompOpDescMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
    PADDLE_ENFORCE_EQ(
        info->grad_comp_op_maker_,
        nullptr,
        platform::errors::AlreadyExists(
265
            "CompositeGradOpMakerBase of %s has been registered", op_type));
J
Jiabin Yang 已提交
266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281

    info->grad_comp_op_maker_ =
        [](const OpDesc& fwd_op,
           const std::unordered_set<std::string>& no_grad_set,
           std::unordered_map<std::string, std::string>* grad_to_var,
           const BlockDesc* current_block,
           const std::vector<BlockDesc*>& grad_block) {
          T maker(fwd_op, no_grad_set, grad_to_var, current_block, grad_block);
          return maker();
        };
    // TODO(jiabin): Support this later or just not.
    info->use_default_grad_op_desc_maker_ = false;
    info->use_empty_grad_op_desc_maker_ = false;
  }
};

H
hong 已提交
282 283 284
template <typename T>
struct OpInfoFiller<T, kGradOpBaseMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
285
    PADDLE_ENFORCE_EQ(
286 287
        info->dygraph_grad_op_maker_,
        nullptr,
Z
Zeng Jinle 已提交
288 289 290
        platform::errors::AlreadyExists(
            "GradOpBaseMaker of %s has been registered", op_type));

291 292 293 294 295 296 297 298 299 300 301
    info->dygraph_grad_op_maker_ =
        [](const std::string& type,
           const imperative::NameVarBaseMap& var_base_map_in,
           const imperative::NameVarBaseMap& var_base_map_out,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs,
           const std::map<std::string, std::string>& inplace_map) {
          T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
          maker.SetDygraphDefaultAttrsMap(default_attrs);
          return maker();
        };
302 303
  }
};
Y
Yu Yang 已提交
304 305 306 307

template <typename T>
struct OpInfoFiller<T, kVarTypeInference> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
308
    PADDLE_ENFORCE_EQ(
309 310
        info->infer_var_type_,
        nullptr,
Z
Zeng Jinle 已提交
311 312
        platform::errors::AlreadyExists(
            "VarTypeInference of %s has been registered", op_type));
M
minqiyang 已提交
313
    info->infer_var_type_ = [](InferVarTypeContext* context) {
Y
Yu Yang 已提交
314
      T inference;
M
minqiyang 已提交
315
      inference(context);
Y
Yu Yang 已提交
316 317 318 319
    };
  }
};

320 321 322
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
  void operator()(const char* op_type, OpInfo* info) const {
323 324
    // Note: if fill InferShapeFN by this Filler, the infershape here
    // will overwrite the op->InferShape func registered in kOperator Filler
325 326 327 328 329 330 331
    info->infer_shape_ = [](InferShapeContext* ctx) {
      T inference;
      inference(ctx);
    };
  }
};

D
dzhwinter 已提交
332 333 334
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
335
    PADDLE_ENFORCE_EQ(
336 337
        info->infer_inplace_,
        nullptr,
Z
Zeng Jinle 已提交
338 339
        platform::errors::AlreadyExists(
            "InplaceOpInference of %s has been registered", op_type));
340
    info->infer_inplace_ = [](bool use_cuda) {
D
dzhwinter 已提交
341
      T infer;
342
      return infer(use_cuda);
D
dzhwinter 已提交
343 344 345 346
    };
  }
};

S
sneaxiy 已提交
347 348 349
template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
350
    PADDLE_ENFORCE_EQ(
351 352
        info->infer_no_need_buffer_vars_,
        nullptr,
Z
Zeng Jinle 已提交
353 354
        platform::errors::AlreadyExists(
            "NoNeedBufferVarsInference of %s has been registered", op_type));
355
    info->infer_no_need_buffer_vars_.Reset(std::make_shared<T>());
S
sneaxiy 已提交
356 357 358
  }
};

359 360 361 362 363 364
// No-op filler for `void`. std::is_base_of is false for void against every
// registered base, so OpInfoFillTypeID<void>::ID() is kUnknown and
// OpInfoFiller<void> resolves to this specialization.
template <>
struct OpInfoFiller<void, kUnknown> {
  void operator()(const char* /*op_type*/, OpInfo* /*info*/) const {}
};

}  // namespace details

}  // namespace framework
}  // namespace paddle