builtin_op.cc 10.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
18
#include "paddle/ir/core/enforce.h"
19
#include "paddle/phi/core/enforce.h"
20 21

namespace ir {
22

23
const char *ModuleOp::attributes_name[attributes_num] = {"program"};  // NOLINT
24 25

// Returns the Program stored in the "program" PointerAttribute.
// Returns nullptr when the attribute is missing, null, or not a
// PointerAttribute.
Program *ModuleOp::program() {
  const AttributeMap &attr = this->attributes();
  auto iter = attr.find("program");
  if (iter == attr.end() || !iter->second) return nullptr;
  // dyn_cast yields a null attribute on a type mismatch; calling data() on it
  // would presumably dereference null storage, so bail out instead.
  auto pointer_attr = iter->second.dyn_cast<PointerAttribute>();
  if (!pointer_attr) return nullptr;
  return static_cast<Program *>(pointer_attr.data());
}

// Returns the first (and expected only) block of the op's single region.
// The structural expectations (non-null operation, one region, one block)
// are checked with assert() only, so they are not enforced when NDEBUG is
// defined.
Block *ModuleOp::block() {
  assert(operation() != nullptr);
  assert(operation()->num_regions() == 1);
  assert(operation()->region(0).size() == 1);
  return operation()->region(0).front();
}

// Creates a ModuleOp registered under name() in `context`.
// The new op has one region, into which one default-constructed entry is
// emplaced (the block later returned by ModuleOp::block()). `pointer` is
// stored in the "program" attribute as a PointerAttribute.
ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
  ir::OpInfo info = context->GetRegisteredOpInfo(name());
  OperationArgument argument(info);
  argument.AddRegion()->emplace_back();
  argument.AddAttribute("program", PointerAttribute::get(context, pointer));
  return ModuleOp(Operation::Create(std::move(argument)));
}

void ModuleOp::Destroy() {
49
  if (operation()) {
50
    operation()->Destroy();
51 52 53 54
    *this = ModuleOp(nullptr);
  }
}

55
void ModuleOp::Verify() const {
56
  VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
57 58
  // Verify inputs:
  IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
59

60 61
  // Verify attributes:
  auto &attributes = this->attributes();
62
  auto iter = attributes.find("program");
63 64
  IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
             "Type of attribute: program is not right.");
65

66 67
  // Verify outputs:
  IR_ENFORCE(num_results() == 0u, "The size of inputs must be equal to 0.");
68 69
}

70
// Attribute names carried by GetParameterOp; "parameter_name" is a
// StrAttribute naming the parameter to read.
const char *GetParameterOp::attributes_name[attributes_num] = {  // NOLINT
    "parameter_name"};

// Populates `argument` for a GetParameterOp.
// `name` is stored as a StrAttribute under attributes_name[0]
// ("parameter_name"), and a single result of type `type` is declared.
void GetParameterOp::Build(Builder &builder,
                           OperationArgument &argument,
                           const std::string &name,
                           Type type) {
  auto *ctx = builder.ir_context();
  auto name_attr = ir::StrAttribute::get(ctx, name);
  argument.attributes[attributes_name[0]] = name_attr;
  argument.output_types.emplace_back(type);
}

82
void GetParameterOp::Verify() const {
83
  VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
84 85
  // Verify inputs:
  IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
86

87
  // Verify if attributes contain attribute name in attributes_name:
88
  auto &attributes = this->attributes();
89 90 91 92 93
  auto iter = attributes.find("parameter_name");
  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
             "Type of attribute: parameter_name is not right.");

  // Verify outputs type:
94
  IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
95 96
}

97
// Attribute names carried by SetParameterOp; "parameter_name" is a
// StrAttribute naming the parameter to write.
const char *SetParameterOp::attributes_name[attributes_num] = {  // NOLINT
    "parameter_name"};

// Populates `argument` for a SetParameterOp.
// `parameter` is added as the only operand, and `name` is recorded as a
// StrAttribute under attributes_name[0] ("parameter_name").
void SetParameterOp::Build(Builder &builder,             // NOLINT
                           OperationArgument &argument,  // NOLINT
                           OpResult parameter,
                           const std::string &name) {
  argument.AddOperand(parameter);
  auto *ctx = builder.ir_context();
  auto name_attr = ir::StrAttribute::get(ctx, name);
  argument.AddAttribute(attributes_name[0], name_attr);
}
108
void SetParameterOp::Verify() const {
109
  VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
110 111
  // Verify inputs:
  IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
112

113 114
  // Verify attributes:
  auto &attributes = this->attributes();
115 116 117 118
  auto iter = attributes.find("parameter_name");
  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
             "Type of attribute: parameter_name is not right.");

119 120
  // Verify outputs:
  IR_ENFORCE(num_results() == 0u, "The size of outputs must be equal to 0.");
121 122
}

123 124 125 126 127 128 129 130 131 132 133 134
// Populates `argument` for a CombineOp.
// All `inputs` become operands. A single VectorType result is declared whose
// element types are the input types, in the same order.
void CombineOp::Build(Builder &builder,
                      OperationArgument &argument,
                      const std::vector<ir::OpResult> &inputs) {
  argument.inputs = inputs;

  std::vector<ir::Type> element_types;
  element_types.reserve(inputs.size());
  for (const auto &value : inputs) {
    element_types.push_back(value.type());
  }

  auto *ctx = builder.ir_context();
  argument.output_types.emplace_back(ir::VectorType::get(ctx, element_types));
}

135
void CombineOp::Verify() const {
136
  // outputs.size() == 1
137 138 139 140 141 142
  IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");

  // output_type == Vector<Type>
  auto output_type = (*this)->result(0).type().dyn_cast<VectorType>();
  IR_ENFORCE(output_type,
             "The type of outputs[0] must be equal to VectorType.");
143

144
  // inputs.size() == outputs[0].size()
145 146 147
  auto input_num = num_operands();
  IR_ENFORCE(output_type.size() == input_num,
             "The size %d of output must be equal to size %d of inputs.",
148
             output_type.size(),
149
             input_num);
150 151

  // forall i in inputs.size(): inputs[i].type == outputs[0][i].type
152 153 154
  for (size_t i = 0; i < input_num; ++i) {
    auto type = (*this)->operand(i).type();
    IR_ENFORCE(output_type[i] == type,
155 156 157 158
               "The type %s of outputs[0][%d] must be "
               "equal to type %s of inputs[%d].",
               output_type[i],
               i,
159
               type,
160
               i);
161 162 163
  }
}

164
// Attribute names carried by SliceOp; "index" is the Int32Attribute position
// of the extracted element (checked by SliceOp::Verify).
const char *SliceOp::attributes_name[attributes_num] = {"index"};  // NOLINT

// Populates `argument` for a SliceOp that extracts element `index` of the
// VectorType-typed `input`.
// The single result takes that element's type. `index` is recorded as an
// Int32Attribute under attributes_name[0] ("index"), which SliceOp::Verify
// requires. Throws via IR_ENFORCE if `input` is not a VectorType or `index`
// is out of range.
void SliceOp::Build(Builder &builder,
                    OperationArgument &argument,
                    const ir::OpResult &input,
                    int index) {
  auto input_type = input.type().dyn_cast<ir::VectorType>();
  IR_ENFORCE(input_type, "The type of input must be equal to VectorType.");
  // Validate before indexing: an out-of-range index would read past the end
  // of the element list.
  IR_ENFORCE(index >= 0 && static_cast<size_t>(index) < input_type.size(),
             "The index %d must be in range [0, %d) of the input VectorType.",
             index,
             input_type.size());

  argument.inputs = {input};
  argument.output_types.emplace_back(input_type[static_cast<size_t>(index)]);
  argument.AddAttribute(attributes_name[0],
                        ir::Int32Attribute::get(builder.ir_context(), index));
}

void SliceOp::Verify() const {
177
  // inputs.size() == 1
178 179 180
  auto input_size = num_operands();
  IR_ENFORCE(
      input_size == 1, "The size %d of inputs must be equal to 1.", input_size);
181 182

  // inputs[0].type == Vector<Type>
183 184
  auto input_type = (*this)->operand(0).type().dyn_cast<ir::VectorType>();
  IR_ENFORCE(input_type,
185
             "The type %s of inputs[0] must be equal to VectorType.",
186
             input_type);
187

188
  auto output_size = num_results();
189
  // outputs.size() == 1
190
  IR_ENFORCE(output_size == 1,
191
             "The size %d of outputs must be equal to 1.",
192
             output_size);
193 194

  // attributes contains index: Int32
195
  auto &attributes = this->attributes();
196 197
  IR_ENFORCE(attributes.count("index") != 0,
             "The attributes must contains index.");
198
  const ir::Attribute &attr = attributes.at("index");
Z
zhangbo9674 已提交
199
  IR_ENFORCE(attr.isa<ir::Int32Attribute>(),
200
             "The attribute index must be INT32.");
Z
zhangbo9674 已提交
201
  auto index = attr.dyn_cast<ir::Int32Attribute>().data();
202 203

  // index >= 0 and < inputs[0].size()
204 205 206 207 208 209
  IR_ENFORCE(
      index >= 0, "The index %d must be greater or equal than 0.", index);
  IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
             "The index %d must be less or equal than size %d of inputs[0].",
             index,
             input_type.size());
210 211

  // inputs[index].type == outputs[0].type
212
  auto output_type = (*this)->result(0).type();
213
  IR_ENFORCE(
214
      input_type[index] == output_type,
215
      "The type %s of inputs[%d] must be equal to type %s of outputs[0].",
216
      input_type[index],
217
      index,
218
      output_type);
219 220
}

221 222 223 224 225 226 227 228 229 230 231
// Populates `argument` for a SplitOp.
// `input` (expected to be VectorType-typed) is the only operand, and one
// result is declared per vector element, with result i typed as element i.
// Throws via IR_ENFORCE if `input` is not a VectorType.
void SplitOp::Build(Builder &builder,
                    OperationArgument &argument,
                    const ir::OpResult &input) {
  // Cast once instead of on every loop iteration. A non-vector input yields
  // a null VectorType, so reject it before calling size()/operator[] on it.
  auto input_type = input.type().dyn_cast<ir::VectorType>();
  IR_ENFORCE(input_type, "The type of input must be equal to VectorType.");

  argument.inputs = {input};
  for (size_t idx = 0; idx < input_type.size(); ++idx) {
    argument.output_types.emplace_back(input_type[idx]);
  }
}

void SplitOp::Verify() const {
  // inputs.size() == 1
  IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");

  // input_type == Vector<Type>
  auto input_type = (*this)->operand(0).type().dyn_cast<VectorType>();
  IR_ENFORCE(input_type, "The type of inputs[0] must be equal to VectorType.");

  // inputs[0].size() == outputs.size()
  auto output_num = num_results();
  IR_ENFORCE(input_type.size() == output_num,
             "The size %d of output must be equal to size %d of inputs.",
             output_num,
             input_type.size());

  // for all i in outputs.size(): outputs[i].type == inputs[0][i].type
  for (size_t i = 0; i < output_num; ++i) {
    auto type = (*this)->result(i).type();
    IR_ENFORCE(input_type[i] == type,
               "The type %s of inputs[0][%d] must be "
               "equal to type %s of outputs[%d].",
               input_type[i],
               i,
               type,
               i);
  }
}

260
const char *ConstantOp::attributes_name[attributes_num] = {"value"};  // NOLINT
261

262 263
void ConstantOp::Build(Builder &builder,
                       OperationArgument &argument,
264 265 266 267 268 269
                       Attribute value,
                       Type output_type) {
  argument.AddAttribute("value", value);
  argument.output_types.push_back(output_type);
}

270
void ConstantOp::Verify() const {
271 272 273
  IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
  IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
  IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
K
kangguangli 已提交
274 275
}

276
// Returns the "value" attribute; AttributeMap::at is used, so a missing
// attribute is reported by at() rather than silently defaulted.
Attribute ConstantOp::value() const { return attributes().at("value"); }

}  // namespace ir
279 280 281 282 283 284

IR_DEFINE_EXPLICIT_TYPE_ID(ir::ModuleOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp)
285
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SplitOp)
286 287
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp)