operation.h 2.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>

#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/op_info.h"
#include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h"

namespace ir {
23 24 25 26
template <class ConcreteTrait>
class OpTraitBase;
template <typename ConcreteInterface>
class OpInterfaceBase;
27 28 29 30 31 32 33 34 35

class alignas(8) Operation final {
 public:
  ///
  /// \brief Malloc memory and construct objects in the following order:
  /// OpResultImpls|Operation|OpOperandImpls.
  ///
  static Operation *create(const std::vector<ir::OpResult> &inputs,
                           const std::vector<ir::Type> &output_types,
36 37
                           ir::DictionaryAttribute attribute,
                           ir::OpInfo op_info);
38 39 40 41 42 43 44

  void destroy();

  ir::OpResult GetResultByIndex(uint32_t index);

  std::string print();

45
  ir::DictionaryAttribute attribute() const { return attribute_; }
46

47
  ir::OpInfo op_info() const { return op_info_; }
48

49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
  uint32_t num_results() const { return num_results_; }

  uint32_t num_operands() const { return num_operands_; }

  template <typename T>
  T dyn_cast() const {
    return CastUtil<T>::call(this);
  }

  template <typename Trait>
  bool HasTrait() const {
    return op_info_.HasTrait<Trait>();
  }

  template <typename Interface>
  bool HasInterface() const {
    return op_info_.HasInterface<Interface>();
  }
67 68 69 70

 private:
  Operation(uint32_t num_results,
            uint32_t num_operands,
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
            ir::DictionaryAttribute attribute,
            ir::OpInfo op_info);

  template <typename T, typename Enabler = void>
  struct CastUtil {
    static T call(const Operation *op) {
      throw("Can't dyn_cast to T, T should be a Trait or Interface");
    }
  };
  template <typename T>
  struct CastUtil<T,
                  typename std::enable_if<
                      std::is_base_of<OpTraitBase<T>, T>::value>::type> {
    static T call(const Operation *op) { return T(op); }
  };
  template <typename T>
  struct CastUtil<T,
                  typename std::enable_if<
                      std::is_base_of<OpInterfaceBase<T>, T>::value>::type> {
    static T call(const Operation *op) {
      return T(op, op->op_info_.impl()->GetInterfaceImpl<T>());
    }
  };
94 95 96

  ir::DictionaryAttribute attribute_;

97 98
  ir::OpInfo op_info_;

99 100 101 102 103 104
  uint32_t num_results_ = 0;

  uint32_t num_operands_ = 0;
};

}  // namespace ir