op_meta_info.h 24.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <iostream>
18 19 20 21
#include <string>
#include <unordered_map>
#include <vector>

22 23 24 25
#include "paddle/pten/api/ext/dll_decl.h"
#include "paddle/pten/api/ext/exception.h"
#include "paddle/pten/api/include/tensor.h"
#include "paddle/utils/any.h"
26 27 28 29 30 31 32 33 34 35

/**
 * Op meta info related definitions.
 *
 * Used to describe and register the core information of custom operators.
 */

namespace paddle {
namespace framework {
36
class PADDLE_API OpMetaInfoHelper;
37 38 39 40
}  // namespace framework

using Tensor = paddle::Tensor;

41 42
///////////////// Util Macro Define ////////////////

// Deletes the copy/move constructors and copy/move assignment operators of
// `classname`. NOTE: the expansion starts with `private:`, so every member
// declared after it in the class body is private too. The call site supplies
// the trailing `;`.
#define PD_DISABLE_COPY_AND_ASSIGN(classname)      \
 private:                                          \
  classname(const classname&) = delete;            \
  classname& operator=(const classname&) = delete; \
  classname(classname&&) = delete;                 \
  classname& operator=(classname&&) = delete

// Compile-time guard that the expansion site is the global namespace: a
// struct is declared in the current scope and compared with the struct of the
// same name looked up from `::`. Anywhere else the expansion fails to compile.
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                          \
  struct __test_global_namespace_##uniq_name##__ {};                            \
  static_assert(std::is_same<__test_global_namespace_##uniq_name##__,           \
                             ::__test_global_namespace_##uniq_name##__>::value, \
                msg)

///////////////// Util Define and Function ////////////////

constexpr char kGradTensorSuffix[] = "@GRAD";
constexpr char kTensorVectorSuffix[] = "@VECTOR";

// Builds the gradient tensor name for `t_name`, e.g. "X" -> "X@GRAD".
inline std::string Grad(const std::string& t_name) {
  std::string result;
  // The suffix length comes from the constant itself (sizeof counts the
  // terminating '\0'), so the reservation cannot drift from the suffix.
  result.reserve(t_name.size() + sizeof(kGradTensorSuffix) - 1);
  result += t_name;
  result += kGradTensorSuffix;
  return result;
}

// Builds the name used for a std::vector<Tensor> argument,
// e.g. "X" -> "X@VECTOR".
inline std::string Vec(const std::string& t_name) {
  std::string result;
  result.reserve(t_name.size() + sizeof(kTensorVectorSuffix) - 1);
  result += t_name;
  result += kTensorVectorSuffix;
  return result;
}

////////////////////// Kernel Function (PD_KERNEL) ////////////////////////

// Record Op kernel core function
82 83 84
using KernelFunc =
    std::vector<Tensor> (*)(const std::vector<Tensor>& inputs,
                            const std::vector<std::vector<Tensor>>& vec_inputs,
85
                            const std::vector<paddle::any>& attrs);
86 87 88 89

// Defines the ComputeCallHelper specialization for a kernel parameter of
// attribute type `attr_type`. It converts attrs[attr_idx] with
// paddle::any_cast, appends it to the already-unpacked arguments and
// continues the recursion with `attr_idx + 1`.
// Only the cast itself is guarded by the try/catch. A paddle::bad_any_cast
// raised further down the chain (e.g. inside the user kernel) is therefore
// not misreported as a cast failure of this attribute.
// Must be expanded inside KernelFuncImpl, where `Return` and the primary
// ComputeCallHelper template are declared. The call site supplies the
// trailing `;`.
#define PD_SPECIALIZE_ComputeCallHelper(attr_type)                            \
  template <typename... Tail>                                                 \
  struct ComputeCallHelper<attr_type, Tail...> {                              \
    template <int in_idx,                                                     \
              int vec_in_idx,                                                 \
              int attr_idx,                                                   \
              typename... PreviousArgs>                                       \
    static Return Compute(const std::vector<Tensor>& inputs,                  \
                          const std::vector<std::vector<Tensor>>& vec_inputs, \
                          const std::vector<paddle::any>& attrs,              \
                          const PreviousArgs&... pargs) {                     \
      attr_type arg = CastAttr(attrs[attr_idx]);                              \
      return ComputeCallHelper<Tail...>::template Compute<in_idx,             \
                                                          vec_in_idx,         \
                                                          attr_idx + 1>(      \
          inputs, vec_inputs, attrs, pargs..., arg);                          \
    }                                                                         \
                                                                              \
   private:                                                                   \
    static attr_type CastAttr(const paddle::any& attr) {                      \
      try {                                                                   \
        return paddle::any_cast<attr_type>(attr);                             \
      } catch (const paddle::bad_any_cast&) {                                 \
        PD_THROW(                                                             \
            "Attribute cast error in custom operator. Expected " #attr_type   \
            " value.");                                                       \
      }                                                                       \
    }                                                                         \
  }

template <typename T>
struct TypeTag {};

template <typename F, F f>
struct KernelFuncImpl;

template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
120 121
  static Return Compute(const std::vector<Tensor>& inputs,
                        const std::vector<std::vector<Tensor>>& vec_inputs,
122
                        const std::vector<paddle::any>& attrs) {
123 124
    return ComputeCallHelper<Args..., TypeTag<int>>::template Compute<0, 0, 0>(
        inputs, vec_inputs, attrs);
125 126 127 128 129 130 131 132
  }

 private:
  template <typename... RemainingArgs>
  struct ComputeCallHelper;

  template <typename... Tail>
  struct ComputeCallHelper<const Tensor&, Tail...> {
133 134 135
    template <int in_idx,
              int vec_in_idx,
              int attr_idx,
136
              typename... PreviousArgs>
137 138
    static Return Compute(const std::vector<Tensor>& inputs,
                          const std::vector<std::vector<Tensor>>& vec_inputs,
139
                          const std::vector<paddle::any>& attrs,
140 141
                          const PreviousArgs&... pargs) {
      const Tensor& arg = inputs[in_idx];
142
      return ComputeCallHelper<Tail...>::template Compute<in_idx + 1,
143 144
                                                          vec_in_idx,
                                                          attr_idx>(
145 146 147 148 149 150
          inputs, vec_inputs, attrs, pargs..., arg);
    }
  };

  template <typename... Tail>
  struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> {
151 152 153
    template <int in_idx,
              int vec_in_idx,
              int attr_idx,
154
              typename... PreviousArgs>
155 156
    static Return Compute(const std::vector<Tensor>& inputs,
                          const std::vector<std::vector<Tensor>>& vec_inputs,
157
                          const std::vector<paddle::any>& attrs,
158 159
                          const PreviousArgs&... pargs) {
      const std::vector<Tensor>& arg = vec_inputs[vec_in_idx];
160 161 162 163
      return ComputeCallHelper<Tail...>::template Compute<in_idx,
                                                          vec_in_idx + 1,
                                                          attr_idx>(
          inputs, vec_inputs, attrs, pargs..., arg);
164 165 166
    }
  };

167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
  PD_SPECIALIZE_ComputeCallHelper(const bool&);
  PD_SPECIALIZE_ComputeCallHelper(const int&);
  PD_SPECIALIZE_ComputeCallHelper(const float&);
  PD_SPECIALIZE_ComputeCallHelper(const int64_t&);
  PD_SPECIALIZE_ComputeCallHelper(const std::string&);
  PD_SPECIALIZE_ComputeCallHelper(const std::vector<int>&);
  PD_SPECIALIZE_ComputeCallHelper(const std::vector<float>&);
  PD_SPECIALIZE_ComputeCallHelper(const std::vector<int64_t>&);
  PD_SPECIALIZE_ComputeCallHelper(const std::vector<std::string>&);
  // TODO(chenweihang): support other attribute type if needed.
  // Why not support other attribute type here?
  // - boost::blank, std::vector<bool> and std::vector<double>
  //   are not used in op
  // - BlockDesc* and std::vector<BlockDesc*> are used in framework

  // NOTE(chenweihang): Used to be compatible with the 2.0.1 released
  // interface, and will be deprecated in the future
184 185 186 187 188 189 190 191 192
  PD_SPECIALIZE_ComputeCallHelper(bool);
  PD_SPECIALIZE_ComputeCallHelper(int);
  PD_SPECIALIZE_ComputeCallHelper(float);
  PD_SPECIALIZE_ComputeCallHelper(int64_t);
  PD_SPECIALIZE_ComputeCallHelper(std::string);
  PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
  PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
  PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
  PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
193

194 195 196
  // end: base template
  template <typename T>
  struct ComputeCallHelper<TypeTag<T>> {
197
    template <int in_idx, int vec_in_idx, int attr_idx>
198 199
    static Return Compute(const std::vector<Tensor>& inputs,
                          const std::vector<std::vector<Tensor>>& vec_inputs,
200
                          const std::vector<paddle::any>& attrs,
201
                          const Args&... args) {
202 203 204 205 206 207 208 209 210 211 212 213
      return impl_fn(args...);
    }
  };
};

#define PD_KERNEL(...) \
  ::paddle::KernelFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute

/////////////// InferShape Function (PD_INFER_SHAPE) ///////////////

// Record Op infershape core function
using InferShapeFunc = std::vector<std::vector<int64_t>> (*)(
214
    const std::vector<std::vector<int64_t>>& input_shapes,
215
    const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
216 217
    const std::vector<paddle::any>& attrs);

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
// Defines the InferShapeCallHelper specialization for a parameter of type
// `input_type` that receives the shape of a single input. It consumes
// input_shapes[in_idx] and continues the recursion with `in_idx + 1`.
// Expanded inside InferShapeFuncImpl; the call site supplies the trailing
// `;`.
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type)            \
  template <typename... Tail>                                               \
  struct InferShapeCallHelper<input_type, Tail...> {                        \
    template <int in_idx,                                                   \
              int vec_in_idx,                                               \
              int attr_idx,                                                 \
              typename... PreviousArgs>                                     \
    static Return InferShape(                                               \
        const std::vector<std::vector<int64_t>>& input_shapes,              \
        const std::vector<std::vector<std::vector<int64_t>>>&               \
            vec_input_shapes,                                               \
        const std::vector<paddle::any>& attrs,                              \
        const PreviousArgs&... pargs) {                                     \
      input_type arg = input_shapes[in_idx];                                \
      return InferShapeCallHelper<Tail...>::template InferShape<in_idx + 1, \
                                                                vec_in_idx, \
                                                                attr_idx>(  \
          input_shapes, vec_input_shapes, attrs, pargs..., arg);            \
    }                                                                       \
  }

// Defines the InferShapeCallHelper specialization for a parameter of type
// `input_type` that receives the shapes of one std::vector<Tensor> input. It
// consumes vec_input_shapes[vec_in_idx] and continues the recursion with
// `vec_in_idx + 1`. Expanded inside InferShapeFuncImpl; the call site
// supplies the trailing `;`.
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type)    \
  template <typename... Tail>                                        \
  struct InferShapeCallHelper<input_type, Tail...> {                 \
    template <int in_idx,                                            \
              int vec_in_idx,                                        \
              int attr_idx,                                          \
              typename... PreviousArgs>                              \
    static Return InferShape(                                        \
        const std::vector<std::vector<int64_t>>& input_shapes,       \
        const std::vector<std::vector<std::vector<int64_t>>>&        \
            vec_input_shapes,                                        \
        const std::vector<paddle::any>& attrs,                       \
        const PreviousArgs&... pargs) {                              \
      input_type arg = vec_input_shapes[vec_in_idx];                 \
      return InferShapeCallHelper<Tail...>::                         \
          template InferShape<in_idx, vec_in_idx + 1, attr_idx>(     \
              input_shapes, vec_input_shapes, attrs, pargs..., arg); \
    }                                                                \
  }

// Defines the InferShapeCallHelper specialization for an attribute parameter
// of type `attr_type`. It converts attrs[attr_idx] with paddle::any_cast and
// continues the recursion with `attr_idx + 1`.
// Only the cast itself is guarded by the try/catch. A paddle::bad_any_cast
// raised further down the chain (e.g. inside the user's InferShapeFn) is
// therefore not misreported as a cast failure of this attribute.
// Expanded inside InferShapeFuncImpl; the call site supplies the trailing
// `;`.
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(attr_type)               \
  template <typename... Tail>                                                \
  struct InferShapeCallHelper<attr_type, Tail...> {                          \
    template <int in_idx,                                                    \
              int vec_in_idx,                                                \
              int attr_idx,                                                  \
              typename... PreviousArgs>                                      \
    static Return InferShape(                                                \
        const std::vector<std::vector<int64_t>>& input_shapes,               \
        const std::vector<std::vector<std::vector<int64_t>>>&                \
            vec_input_shapes,                                                \
        const std::vector<paddle::any>& attrs,                               \
        const PreviousArgs&... pargs) {                                      \
      attr_type arg = CastAttr(attrs[attr_idx]);                             \
      return InferShapeCallHelper<Tail...>::                                 \
          template InferShape<in_idx, vec_in_idx, attr_idx + 1>(             \
              input_shapes, vec_input_shapes, attrs, pargs..., arg);         \
    }                                                                        \
                                                                             \
   private:                                                                  \
    static attr_type CastAttr(const paddle::any& attr) {                     \
      try {                                                                  \
        return paddle::any_cast<attr_type>(attr);                            \
      } catch (const paddle::bad_any_cast&) {                                \
        PD_THROW(                                                            \
            "Attribute cast error in custom operator InferShapeFn. "         \
            "Expected " #attr_type                                           \
            " value. InferShapeFn's attribute list must be exactly same as " \
            "Forward "                                                       \
            "KernelFn's attribute list except std::vector<int64_t> "         \
            "attribute.");                                                   \
      }                                                                      \
    }                                                                        \
  }

template <typename F, F f>
struct InferShapeFuncImpl;

template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
294
  static Return InferShape(
295
      const std::vector<std::vector<int64_t>>& input_shapes,
296
      const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
297
      const std::vector<paddle::any>& attrs) {
298 299 300 301
    return InferShapeCallHelper<Args..., TypeTag<int>>::template InferShape<0,
                                                                            0,
                                                                            0>(
        input_shapes, vec_input_shapes, attrs);
302 303 304 305 306 307
  }

 private:
  template <typename... RemainingArgs>
  struct InferShapeCallHelper;

308 309 310
  PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(const std::vector<int64_t>&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(
      const std::vector<std::vector<int64_t>>&);
311

312 313 314 315 316
  // NOTE(chenweihang): Used to be compatible with the 2.0.1 released
  // interface, and will be deprecated in the future
  PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(std::vector<int64_t>);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(
      std::vector<std::vector<int64_t>>);
317

318 319 320 321 322 323 324 325 326 327 328 329
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::string&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<int>&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<float>&);
  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<std::string>&);
  // NOTE(chenweihang): InferShape can't support std::vector<int64_t> attr type,
  // because the input type is std::vector<int64_t>, only can use one rule to
  // parse std::vector<int64_t> parameter

330 331 332
  // end: base template
  template <typename T>
  struct InferShapeCallHelper<TypeTag<T>> {
333
    template <int in_idx, int vec_in_idx, int attr_idx>
334
    static Return InferShape(
335 336
        const std::vector<std::vector<int64_t>>& input_shapes,
        const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
337 338
        const std::vector<paddle::any>& attrs,
        const Args&... args) {
339 340 341 342 343 344 345 346 347 348 349
      return impl_fn(args...);
    }
  };
};

#define PD_INFER_SHAPE(...) \
  ::paddle::InferShapeFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::InferShape

/////////////// InferDataType Function (PD_INFER_DTYPE) ///////////////

// Record Op Infer dtype core function
350
using InferDtypeFunc = std::vector<DataType> (*)(
351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368
    const std::vector<DataType>& input_dtypes,
    const std::vector<std::vector<DataType>>& vec_input_dtypes);

// Defines the InferDtypeCallHelper specialization for a parameter of type
// `input_type` that receives the dtype of a single input. It reads
// input_dtypes[in_idx] and continues the recursion with `in_idx + 1`.
// Expanded inside InferDtypeFuncImpl; the call site supplies the trailing
// `;`.
#define PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(input_type)     \
  template <typename... Tail>                                       \
  struct InferDtypeCallHelper<input_type, Tail...> {                \
    template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
    static Return InferDtype(                                       \
        const std::vector<DataType>& input_dtypes,                  \
        const std::vector<std::vector<DataType>>& vec_input_dtypes, \
        const PreviousArgs&... collected) {                         \
      input_type current = input_dtypes[in_idx];                    \
      return InferDtypeCallHelper<Tail...>::                        \
          template InferDtype<in_idx + 1, vec_in_idx>(              \
              input_dtypes, vec_input_dtypes, collected..., current); \
    }                                                               \
  }

// Defines the InferDtypeCallHelper specialization for a parameter of type
// `input_type` that receives the dtypes of one std::vector<Tensor> input. It
// consumes vec_input_dtypes[vec_in_idx] and continues the recursion with
// `vec_in_idx + 1`. Expanded inside InferDtypeFuncImpl; the call site
// supplies the trailing `;`.
#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(input_type)   \
  template <typename... Tail>                                       \
  struct InferDtypeCallHelper<input_type, Tail...> {                \
    template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
    static Return InferDtype(                                       \
        const std::vector<DataType>& input_dtypes,                  \
        const std::vector<std::vector<DataType>>& vec_input_dtypes, \
        const PreviousArgs&... pargs) {                             \
      input_type arg = vec_input_dtypes[vec_in_idx];                \
      return InferDtypeCallHelper<Tail...>::                        \
          template InferDtype<in_idx, vec_in_idx + 1>(              \
              input_dtypes, vec_input_dtypes, pargs..., arg);       \
    }                                                               \
  }

template <typename F, F f>
struct InferDtypeFuncImpl;

template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
390
  static Return InferDtype(
391 392
      const std::vector<DataType>& input_dtypes,
      const std::vector<std::vector<DataType>>& vec_input_dtypes) {
393 394 395
    return InferDtypeCallHelper<Args..., TypeTag<int>>::template InferDtype<0,
                                                                            0>(
        input_dtypes, vec_input_dtypes);
396 397 398 399 400 401
  }

 private:
  template <typename... RemainingArgs>
  struct InferDtypeCallHelper;

402 403
  PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(const DataType&);
  PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&);
404

405 406 407 408
  // NOTE(chenweihang): Used to be compatible with the 2.0.1 released
  // interface, and will be deprecated in the future
  PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(DataType);
  PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(std::vector<DataType>);
409 410 411 412

  // end: base template
  template <typename T>
  struct InferDtypeCallHelper<TypeTag<T>> {
413 414
    template <int in_idx, int vec_in_idx>
    static Return InferDtype(
415 416
        const std::vector<DataType>& input_dtypes,
        const std::vector<std::vector<DataType>>& vec_input_dtypes,
417
        const Args&... args) {
418 419 420 421 422 423 424 425 426 427
      return impl_fn(args...);
    }
  };
};

#define PD_INFER_DTYPE(...) \
  ::paddle::InferDtypeFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::InferDtype

////////////////////// Op Meta Info //////////////////////

428
class PADDLE_API OpMetaInfo {
429 430
 public:
  explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
431 432

  // format: {"<name1>", "<name2>", ...}
433
  OpMetaInfo& Inputs(std::vector<std::string>&& inputs);
434 435

  // format: {"<name1>", "<name2>", ...}
436
  OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
437 438 439 440 441

  // format: {"<name1>:<type1>", "<name1>:<type1>", ...}
  OpMetaInfo& Attrs(std::vector<std::string>&& attrs);

  // format: PD_KERNEL(...)
442
  OpMetaInfo& SetKernelFn(KernelFunc&& func);
443 444

  // format: PD_INFER_SHAPE(...)
445
  OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);
446 447

  // format: PD_INFER_DTYPE(...)
448 449 450 451 452 453 454 455 456 457 458 459
  OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);

 private:
  friend class framework::OpMetaInfoHelper;

  // 1. desc info
  std::string name_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
  std::vector<std::string> attrs_;

  // 2. func info
460 461 462
  KernelFunc kernel_fn_{nullptr};
  InferShapeFunc infer_shape_fn_{nullptr};
  InferDtypeFunc infer_dtype_fn_{nullptr};
463 464 465 466
};

//////////////// Op Meta Info Map /////////////////

467
class PADDLE_API OpMetaInfoMap {
468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490
 public:
  // this function's impl should keep in header file.
  // if move to cc file, meta info can not be added
  // into map
  static OpMetaInfoMap& Instance() {
    static OpMetaInfoMap g_custom_op_meta_info_map;
    return g_custom_op_meta_info_map;
  }

  std::vector<OpMetaInfo>& operator[](const std::string& name);

  const std::unordered_map<std::string, std::vector<OpMetaInfo>>& GetMap()
      const;

 private:
  OpMetaInfoMap() = default;
  std::unordered_map<std::string, std::vector<OpMetaInfo>> map_;

  PD_DISABLE_COPY_AND_ASSIGN(OpMetaInfoMap);
};

//////////////// Op Meta Info Builder /////////////////

491
class PADDLE_API OpMetaInfoBuilder {
492
 public:
493
  explicit OpMetaInfoBuilder(std::string&& name, size_t index);
494 495
  OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
  OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
496
  OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
497 498 499
  OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
  OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
  OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
500 501 502 503

 private:
  // Forward Op name
  std::string name_;
504
  // ref current info ptr
505
  OpMetaInfo* info_ptr_;
506 507 508
  // The current op meta info index in vector
  // - 0: op, 1: grad_op, 2: grad_grad_op
  size_t index_;
509 510 511 512 513
};

/////////////////////// Op register API /////////////////////////

// For inference, when custom ops are compiled directly with the framework.
// Call after PD_BUILD_OP(...).
void RegisterAllCustomOperator();

// Loads the compiled custom-operator dynamic library `dso_name` and
// registers the custom operators it contains.
void LoadCustomOperatorLib(const std::string& dso_name);

/////////////////////// Op register Macro /////////////////////////

// Each macro checks that it is used at global namespace scope. It then
// defines a static OpMetaInfoBuilder for `op_name`, copy-initialized from
// the expression the macro ends with, so calls can be chained, e.g.
//   PD_BUILD_OP(my_op).Inputs({"X"}).Outputs({"Out"})
//       .SetKernelFn(PD_KERNEL(MyKernel));
// The index argument selects the entry: 0 forward op, 1 grad op,
// 2 double grad op.

#define PD_BUILD_OP(op_name)                                                  \
  STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##op_name,                         \
                                 "PD_BUILD_OP must be called in global "      \
                                 "namespace.");                               \
  static ::paddle::OpMetaInfoBuilder __op_meta_info_##op_name##__ =           \
      ::paddle::OpMetaInfoBuilder(#op_name, 0)

#define PD_BUILD_GRAD_OP(op_name)                                             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_grad_op__##op_name,                    \
                                 "PD_BUILD_GRAD_OP must be called in global " \
                                 "namespace.");                               \
  static ::paddle::OpMetaInfoBuilder __grad_op_meta_info_##op_name##__ =      \
      ::paddle::OpMetaInfoBuilder(#op_name, 1)

#define PD_BUILD_DOUBLE_GRAD_OP(op_name)                                      \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
      __reg_grad_grad_op__##op_name,                                          \
      "PD_BUILD_DOUBLE_GRAD_OP must be called in global namespace.");         \
  static ::paddle::OpMetaInfoBuilder __grad_grad_op_meta_info_##op_name##__ = \
      ::paddle::OpMetaInfoBuilder(#op_name, 2)

}  // namespace paddle

///////////////////// C API ///////////////////

// Only Windows builds declare anything here, so the whole C linkage block is
// guarded by _WIN32.
#if defined(_WIN32)

#ifdef __cplusplus
extern "C" {
#endif

// C-API to get the global OpMetaInfoMap (exported from the DLL).
__declspec(dllexport) inline paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
  return paddle::OpMetaInfoMap::Instance();
}

#ifdef __cplusplus
}
#endif

#endif  // _WIN32