op_registry.h 12.6 KB
Newer Older
1
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15 16

#pragma once

17
#include <map>
18
#include <memory>
19 20
#include <string>
#include <tuple>
S
sneaxiy 已提交
21
#include <type_traits>
M
minqiyang 已提交
22 23
#include <unordered_map>
#include <unordered_set>
24
#include <vector>
25

Y
Yi Wang 已提交
26
#include "paddle/fluid/framework/grad_op_desc_maker.h"
D
dzhwinter 已提交
27
#include "paddle/fluid/framework/inplace_op_inference.h"
S
sneaxiy 已提交
28
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
Y
Yi Wang 已提交
29 30 31 32
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type_inference.h"
H
hong 已提交
33 34
#include "paddle/fluid/imperative/dygraph_grad_maker.h"
#include "paddle/fluid/imperative/type_defs.h"
J
Jiabin Yang 已提交
35
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
36
#include "paddle/phi/core/macros.h"
37 38 39 40 41 42 43 44

namespace paddle {
namespace framework {
namespace details {

// Identifies which OpInfo slot a registered helper class populates.
// Every value is spelled out explicitly. kUnknown (-1) marks a class that
// derives from none of the registry base classes.
enum OpInfoFillType {
  kOperator = 0,                   // operator class -> OpInfo::creator_
  kOpProtoAndCheckerMaker = 1,     // -> OpInfo::proto_ and OpInfo::checker_
  kGradOpDescMaker = 2,            // -> OpInfo::grad_op_maker_
  kVarTypeInference = 3,           // -> OpInfo::infer_var_type_
  kShapeInference = 4,             // -> OpInfo::infer_shape_
  kInplaceOpInference = 5,         // -> OpInfo::infer_inplace_
  kNoNeedBufferVarsInference = 6,  // -> OpInfo::infer_no_need_buffer_vars_
  kGradOpBaseMaker = 7,            // -> OpInfo::dygraph_grad_op_maker_
  kGradCompOpDescMaker = 8,        // -> OpInfo::grad_comp_op_maker_
  kUnknown = -1,                   // not a recognised registry class
};

S
sneaxiy 已提交
55 56 57 58 59 60 61 62 63 64 65
namespace internal {
template <typename T, OpInfoFillType kType>
struct TypePair {
  using Type = T;
  static constexpr OpInfoFillType kFillType = kType;
};

using OpRegistryClasses = std::tuple<                                // NOLINT
    TypePair<OperatorBase, kOperator>,                               // NOLINT
    TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,       // NOLINT
    TypePair<GradOpDescMakerBase, kGradOpDescMaker>,                 // NOLINT
H
hong 已提交
66
    TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,     // NOLINT
67
    TypePair<prim::CompositeGradOpMakerBase, kGradCompOpDescMaker>,  // NOLINT
S
sneaxiy 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    TypePair<VarTypeInference, kVarTypeInference>,                   // NOLINT
    TypePair<InferShapeBase, kShapeInference>,                       // NOLINT
    TypePair<InplaceOpInference, kInplaceOpInference>,               // NOLINT
    TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference>  // NOLINT
    >;

static constexpr int kOpRegistryClassNumber =
    std::tuple_size<OpRegistryClasses>::value;

template <typename T, int kPos, bool kIsBounded /* = true*/>
struct IsMatchedBaseTypeImpl {
  using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
  static constexpr bool kValue =
      std::is_base_of<typename PairType::Type, T>::value;
};

template <typename T, int kPos>
struct IsMatchedBaseTypeImpl<T, kPos, false> {
  static constexpr bool kValue = false;
};

template <typename T, int kPos>
static inline constexpr bool IsMatchedBaseType() {
91 92 93 94
  return IsMatchedBaseTypeImpl<T,
                               kPos,
                               (kPos >= 0 &&
                                kPos < kOpRegistryClassNumber)>::kValue;
S
sneaxiy 已提交
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
}

template <typename T, int kStart, int kEnd, bool kIsEnd, bool kIsMatched>
struct OpInfoFillTypeGetterImpl {};

// This case should not happen
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, true> {};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, false> {
  static constexpr OpInfoFillType kType = kUnknown;
};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
  static constexpr OpInfoFillType kType =
112 113 114 115
      OpInfoFillTypeGetterImpl<T,
                               kStart + 1,
                               kEnd,
                               kStart + 1 == kEnd,
S
sneaxiy 已提交
116 117 118 119 120 121 122 123 124 125 126
                               IsMatchedBaseType<T, kStart + 1>()>::kType;
};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
  using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
  static constexpr OpInfoFillType kType = PairType::kFillType;
};

template <typename T>
using OpInfoFillTypeGetter =
127 128 129
    OpInfoFillTypeGetterImpl<T,
                             0,
                             kOpRegistryClassNumber,
S
sneaxiy 已提交
130 131 132 133 134
                             kOpRegistryClassNumber == 0,
                             IsMatchedBaseType<T, 0>()>;

}  // namespace internal

135 136 137
// Compile-time mapping from a registered helper class T to the
// OpInfoFillType slot it fills. Yields kUnknown when T derives from none of
// the registry base classes.
template <typename T>
struct OpInfoFillTypeID {
  static constexpr OpInfoFillType ID() {
    using Getter = internal::OpInfoFillTypeGetter<T>;
    return Getter::kType;
  }
};

template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
struct OpInfoFiller;

template <size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;

template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, false, ARGS...> {
 public:
  using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
    OpInfoFiller<T> fill;
    fill(op_type, info);
    constexpr auto size = sizeof...(ARGS);
    OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
                                                                  info);
    (void)(reg);
  }
};

template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, true, ARGS...> {
 public:
165
  OperatorRegistrarRecursive(const char* op_type UNUSED, OpInfo* info UNUSED) {}
166 167 168 169 170
};

template <typename T>
struct OpInfoFiller<T, kOperator> {
  void operator()(const char* op_type, OpInfo* info) const {
171 172
    PADDLE_ENFORCE_EQ(info->creator_,
                      nullptr,
Z
Zeng Jinle 已提交
173 174
                      platform::errors::AlreadyExists(
                          "OpCreator of %s has been registered", op_type));
175 176
    info->creator_ = [](const std::string& type,
                        const VariableNameMap& inputs,
177 178 179 180
                        const VariableNameMap& outputs,
                        const AttributeMap& attrs) {
      return new T(type, inputs, outputs, attrs);
    };
Z
Zeng Jinle 已提交
181 182 183

    if (std::is_base_of<OperatorWithKernel, T>::value) {
      PADDLE_ENFORCE_EQ(
184 185
          info->infer_shape_,
          nullptr,
Z
Zeng Jinle 已提交
186 187 188
          platform::errors::AlreadyExists(
              "Duplicate InferShapeFN of %s has been registered", op_type));

189 190
      OperatorWithKernel* op = dynamic_cast<OperatorWithKernel*>(info->creator_(
          std::string{}, VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
191 192 193
      PADDLE_ENFORCE_NOT_NULL(
          op,
          platform::errors::InvalidArgument("%s should have kernels", op_type));
Z
Zeng Jinle 已提交
194 195 196 197
      info->infer_shape_ = [op](InferShapeContext* ctx) {
        op->InferShape(ctx);
      };
    }
198 199 200 201 202 203
  }
};

template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
204 205
    PADDLE_ENFORCE_EQ(info->proto_,
                      nullptr,
Z
Zeng Jinle 已提交
206
                      platform::errors::AlreadyExists(
207
                          "OpProto of %s has been registered.", op_type));
208 209
    PADDLE_ENFORCE_EQ(info->checker_,
                      nullptr,
Z
Zeng Jinle 已提交
210
                      platform::errors::AlreadyExists(
211
                          "OpAttrChecker of %s has been registered.", op_type));
212
    info->proto_ = new proto::OpProto;
213
    info->checker_ = new OpAttrChecker();
214
    info->proto_->set_type(op_type);
Y
Yu Yang 已提交
215
    T maker;
Y
yuyang18 已提交
216
    maker(info->proto_, info->checker_);
217
    PADDLE_ENFORCE_EQ(
218 219
        info->proto_->IsInitialized(),
        true,
220 221
        platform::errors::PreconditionNotMet(
            "Fail to initialize %s's OpProto, because %s is not initialized.",
222 223
            op_type,
            info->proto_->InitializationErrorString()));
224 225 226 227 228 229
  }
};

template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
230
    PADDLE_ENFORCE_EQ(
231 232
        info->grad_op_maker_,
        nullptr,
Z
Zeng Jinle 已提交
233 234 235
        platform::errors::AlreadyExists(
            "GradOpDescMaker of %s has been registered", op_type));

236 237 238 239 240 241 242 243
    info->grad_op_maker_ =
        [](const OpDesc& fwd_op,
           const std::unordered_set<std::string>& no_grad_set,
           std::unordered_map<std::string, std::string>* grad_to_var,
           const std::vector<BlockDesc*>& grad_block) {
          T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
          return maker();
        };
S
sneaxiy 已提交
244 245

    info->use_default_grad_op_desc_maker_ =
H
hong 已提交
246
        std::is_base_of<DefaultGradOpMaker<OpDesc, true>, T>::value ||
247 248 249 250 251 252 253 254 255
        std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value ||
        std::is_base_of<DefaultGradOpMaker<imperative::OpBase, true>,
                        T>::value ||
        std::is_base_of<DefaultGradOpMaker<imperative::OpBase, false>,
                        T>::value;

    info->use_empty_grad_op_desc_maker_ =
        std::is_base_of<EmptyGradOpMaker<OpDesc>, T>::value ||
        std::is_base_of<EmptyGradOpMaker<imperative::OpBase>, T>::value;
H
hong 已提交
256 257 258
  }
};

J
Jiabin Yang 已提交
259 260 261 262 263 264 265
template <typename T>
struct OpInfoFiller<T, kGradCompOpDescMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
    PADDLE_ENFORCE_EQ(
        info->grad_comp_op_maker_,
        nullptr,
        platform::errors::AlreadyExists(
266
            "CompositeGradOpMakerBase of %s has been registered", op_type));
J
Jiabin Yang 已提交
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282

    info->grad_comp_op_maker_ =
        [](const OpDesc& fwd_op,
           const std::unordered_set<std::string>& no_grad_set,
           std::unordered_map<std::string, std::string>* grad_to_var,
           const BlockDesc* current_block,
           const std::vector<BlockDesc*>& grad_block) {
          T maker(fwd_op, no_grad_set, grad_to_var, current_block, grad_block);
          return maker();
        };
    // TODO(jiabin): Support this later or just not.
    info->use_default_grad_op_desc_maker_ = false;
    info->use_empty_grad_op_desc_maker_ = false;
  }
};

H
hong 已提交
283 284 285
template <typename T>
struct OpInfoFiller<T, kGradOpBaseMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
286
    PADDLE_ENFORCE_EQ(
287 288
        info->dygraph_grad_op_maker_,
        nullptr,
Z
Zeng Jinle 已提交
289 290 291
        platform::errors::AlreadyExists(
            "GradOpBaseMaker of %s has been registered", op_type));

292 293 294 295 296 297 298 299 300 301 302
    info->dygraph_grad_op_maker_ =
        [](const std::string& type,
           const imperative::NameVarBaseMap& var_base_map_in,
           const imperative::NameVarBaseMap& var_base_map_out,
           const framework::AttributeMap& attrs,
           const framework::AttributeMap& default_attrs,
           const std::map<std::string, std::string>& inplace_map) {
          T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
          maker.SetDygraphDefaultAttrsMap(default_attrs);
          return maker();
        };
303 304
  }
};
Y
Yu Yang 已提交
305 306 307 308

template <typename T>
struct OpInfoFiller<T, kVarTypeInference> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
309
    PADDLE_ENFORCE_EQ(
310 311
        info->infer_var_type_,
        nullptr,
Z
Zeng Jinle 已提交
312 313
        platform::errors::AlreadyExists(
            "VarTypeInference of %s has been registered", op_type));
M
minqiyang 已提交
314
    info->infer_var_type_ = [](InferVarTypeContext* context) {
Y
Yu Yang 已提交
315
      T inference;
M
minqiyang 已提交
316
      inference(context);
Y
Yu Yang 已提交
317 318 319 320
    };
  }
};

321 322 323
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
  void operator()(const char* op_type, OpInfo* info) const {
324 325
    // Note: if fill InferShapeFN by this Filler, the infershape here
    // will overwrite the op->InferShape func registered in kOperator Filler
326 327 328 329 330 331 332
    info->infer_shape_ = [](InferShapeContext* ctx) {
      T inference;
      inference(ctx);
    };
  }
};

D
dzhwinter 已提交
333 334 335
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
336
    PADDLE_ENFORCE_EQ(
337 338
        info->infer_inplace_,
        nullptr,
Z
Zeng Jinle 已提交
339 340
        platform::errors::AlreadyExists(
            "InplaceOpInference of %s has been registered", op_type));
341
    info->infer_inplace_ = [](bool use_cuda) {
D
dzhwinter 已提交
342
      T infer;
343
      return infer(use_cuda);
D
dzhwinter 已提交
344 345 346 347
    };
  }
};

S
sneaxiy 已提交
348 349 350
template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
  void operator()(const char* op_type, OpInfo* info) const {
Z
Zeng Jinle 已提交
351
    PADDLE_ENFORCE_EQ(
352 353
        info->infer_no_need_buffer_vars_,
        nullptr,
Z
Zeng Jinle 已提交
354 355
        platform::errors::AlreadyExists(
            "NoNeedBufferVarsInference of %s has been registered", op_type));
356
    info->infer_no_need_buffer_vars_.Reset(std::make_shared<T>());
S
sneaxiy 已提交
357 358 359
  }
};

360 361 362 363 364 365
// A fake OpInfoFiller of void
template <>
struct OpInfoFiller<void, kUnknown> {
  void operator()(const char* op_type, OpInfo* info) const {}
};

366 367 368 369
}  // namespace details

}  // namespace framework
}  // namespace paddle