builtin_op.cc 8.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
18
#include "paddle/ir/core/enforce.h"
19
#include "paddle/phi/core/enforce.h"
20 21

namespace ir {
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

const char *ModuleOp::attributes_name[attributes_num] = {"program"};

// Returns the Program recorded in this op's "program" attribute.
// Yields nullptr when the attribute is missing or holds a null attribute.
Program *ModuleOp::program() {
  const AttributeMap &attr_map = operation()->attributes();
  const auto found = attr_map.find("program");
  const bool missing = (found == attr_map.end()) || !found->second;
  if (missing) {
    return nullptr;
  }
  auto program_attr = found->second.dyn_cast<PointerAttribute>();
  return static_cast<Program *>(program_attr.data());
}

// Returns the first (and expected only) block of the op's single region.
// Debug builds assert that the op exists, has exactly one region, and that
// the region holds exactly one block.
Block *ModuleOp::block() {
  auto *op = operation();
  assert(op != nullptr);
  assert(op->num_regions() == 1);
  auto &body_region = op->GetRegion(0);
  assert(body_region.size() == 1);
  return body_region.front();
}

40
// Creates a ModuleOp from the op info registered in `context`.
// The op gets one region with one appended entry (the shape block() asserts)
// and stores `pointer` in its "program" PointerAttribute.
ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
  ir::OpInfo op_info = context->GetRegisteredOpInfo(name());
  OperationArgument op_args(op_info);

  auto *body_region = op_args.AddRegion();
  body_region->emplace_back();

  auto program_attr = PointerAttribute::get(context, pointer);
  op_args.AddAttribute("program", program_attr);

  return ModuleOp(Operation::Create(std::move(op_args)));
}

48
void ModuleOp::Destroy() {
49
  if (operation()) {
50
    operation()->Destroy();
51 52 53 54
    *this = ModuleOp(nullptr);
  }
}

55
// Checks the ModuleOp signature:
//   - it takes no operands;
//   - it carries a "program" attribute of type PointerAttribute;
//   - it produces no results.
void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
                      const std::vector<ir::Type> &outputs,
                      const ir::AttributeMap &attributes) {
  VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
  IR_ENFORCE(inputs.empty(), "The size of inputs must be equal to 0.");

  const auto program_it = attributes.find("program");
  const bool program_ok = program_it != attributes.end() &&
                          program_it->second.isa<PointerAttribute>();
  IR_ENFORCE(program_ok, "Type of attribute: program is not right.");

  IR_ENFORCE(outputs.empty(), "The size of outputs must be equal to 0.");
}

71 72
const char *GetParameterOp::attributes_name[attributes_num] = {
    "parameter_name"};
73

74
void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
75 76 77 78
                            const std::vector<ir::Type> &outputs,
                            const ir::AttributeMap &attributes) {
  VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
  // Verify inputs type:
79 80
  IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");

81
  // Verify if attributes contain attribute name in attributes_name:
82 83 84 85 86 87
  auto iter = attributes.find("parameter_name");
  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
             "Type of attribute: parameter_name is not right.");

  // Verify outputs type:
  IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
88 89
}

90 91
const char *SetParameterOp::attributes_name[attributes_num] = {
    "parameter_name"};
92

93
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
94 95 96 97
                            const std::vector<ir::Type> &outputs,
                            const ir::AttributeMap &attributes) {
  VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
  // Verify inputs type:
98 99
  IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1.");

100
  // Verify if attributes contain attribute name in attributes_name:
101 102 103 104 105 106
  auto iter = attributes.find("parameter_name");
  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
             "Type of attribute: parameter_name is not right.");

  // Verify outputs type:
  IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
107 108
}

109
// Checks that CombineOp packs its operands into a single VectorType result:
//   - there is exactly one result, and its type is a VectorType;
//   - the vector has as many element types as there are operands;
//   - element type i equals the type of operand i, for every i.
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
                       const std::vector<ir::Type> &outputs,
                       const ir::AttributeMap &attributes) {
  IR_ENFORCE(outputs.size() == 1,
             "The size %d of outputs must be equal to 1.",
             outputs.size());

  const ir::Type &result_type = outputs[0];
  IR_ENFORCE(result_type.isa<ir::VectorType>(),
             "The type %s of outputs[0] must be equal to VectorType.",
             result_type);
  ir::VectorType packed = result_type.dyn_cast<ir::VectorType>();

  const size_t operand_count = inputs.size();
  IR_ENFORCE(packed.size() == operand_count,
             "The size %d of outputs[0] must be equal to size %d of inputs.",
             packed.size(),
             operand_count);

  for (size_t idx = 0; idx < operand_count; ++idx) {
    const ir::Type operand_type = inputs[idx].type();
    IR_ENFORCE(packed[idx] == operand_type,
               "The type %s of outputs[0][%d] must be "
               "equal to type %s of inputs[%d].",
               packed[idx],
               idx,
               operand_type,
               idx);
  }
}

const char *SliceOp::attributes_name[attributes_num] = {"index"};
141
void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
142 143 144
                     const std::vector<ir::Type> &outputs,
                     const ir::AttributeMap &attributes) {
  // inputs.size() == 1
145 146 147
  IR_ENFORCE(inputs.size() == 1,
             "The size %d of inputs must be equal to 1.",
             inputs.size());
148 149

  // inputs[0].type == Vector<Type>
150 151 152
  IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
             "The type %s of inputs[0] must be equal to VectorType.",
             inputs[0].type());
153 154 155
  ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();

  // outputs.size() == 1
156 157 158
  IR_ENFORCE(outputs.size() == 1,
             "The size %d of outputs must be equal to 1.",
             outputs.size());
159 160

  // attributes contains index: Int32
161 162
  IR_ENFORCE(attributes.count("index") != 0,
             "The attributes must contains index.");
163
  const ir::Attribute &attr = attributes.at("index");
Z
zhangbo9674 已提交
164
  IR_ENFORCE(attr.isa<ir::Int32Attribute>(),
165
             "The attribute index must be INT32.");
Z
zhangbo9674 已提交
166
  auto index = attr.dyn_cast<ir::Int32Attribute>().data();
167 168

  // index >= 0 and < inputs[0].size()
169 170 171 172 173 174
  IR_ENFORCE(
      index >= 0, "The index %d must be greater or equal than 0.", index);
  IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
             "The index %d must be less or equal than size %d of inputs[0].",
             index,
             input_type.size());
175 176

  // inputs[index].type == outputs[0].type
177 178 179
  IR_ENFORCE(
      input_type[index] == outputs[0],
      "The type %s of inputs[%d] must be equal to type %s of outputs[0].",
180
      input_type[index],
181 182
      index,
      outputs[0]);
183 184
}

185 186
const char *ConstantOp::attributes_name[attributes_num] = {"value"};

187 188
void ConstantOp::Build(Builder &builder,
                       OperationArgument &argument,
189 190 191 192 193 194 195
                       Attribute value,
                       Type output_type) {
  argument.AddAttribute("value", value);
  argument.output_types.push_back(output_type);
}

// Checks the ConstantOp signature:
//   - it takes no operands;
//   - it produces exactly one result;
//   - a "value" attribute is present (its type is not checked here).
void ConstantOp::Verify(const std::vector<ir::OpResult> &inputs,
                        const std::vector<ir::Type> &outputs,
                        const ir::AttributeMap &attributes) {
  IR_ENFORCE(inputs.empty(), "The size of inputs must be equal to 0.");
  IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
  const bool has_value = attributes.find("value") != attributes.end();
  IR_ENFORCE(has_value, "Type of attribute: value is not right.");
}

204 205
// Returns the attribute stored under "value".
// Uses map `at()`, so it throws if that attribute is missing.
Attribute ConstantOp::value() {
  const AttributeMap &attrs = operation()->attributes();
  return attrs.at("value");
}

206
}  // namespace ir
207 208 209 210 211 212 213 214

// Explicit TypeId definitions for the builtin ops implemented above and for
// ConstantLikeTrait. These are invoked at global scope, outside namespace ir.
// NOTE(review): presumably paired with matching IR_DECLARE_EXPLICIT_TYPE_ID
// declarations in builtin_op.h; confirm when adding a new op here.
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ModuleOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp)