builtin_op.cc 9.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
18
#include "paddle/ir/core/enforce.h"
19
#include "paddle/phi/core/enforce.h"
20 21

namespace ir {
22 23 24 25 26 27 28 29 30 31 32 33 34 35

// Name of the single attribute a ModuleOp carries.
const char *ModuleOp::attributes_name[attributes_num] = {"program"};

// Returns the Program referenced by the "program" PointerAttribute, or
// nullptr when that attribute is missing or holds a null Attribute.
Program *ModuleOp::program() {
  const AttributeMap &attr_map = operation()->attributes();
  const auto found = attr_map.find("program");
  if (found == attr_map.end()) {
    return nullptr;
  }
  const Attribute &program_attr = found->second;
  if (!program_attr) {
    return nullptr;
  }
  return static_cast<Program *>(
      program_attr.dyn_cast<PointerAttribute>().data());
}

// Returns the only block of the only region owned by this module's operation.
// The expected shape (one region, one block) is checked with assert, so it is
// only enforced when NDEBUG is not defined.
Block *ModuleOp::block() {
  auto *op = operation();
  assert(op != nullptr);
  assert(op->num_regions() == 1);
  auto &body = op->region(0);
  assert(body.size() == 1);
  return body.front();
}

40
// Creates a ModuleOp whose operation owns one region holding one block and
// whose "program" attribute is a PointerAttribute wrapping `pointer`.
ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
  ir::OpInfo op_info = context->GetRegisteredOpInfo(name());
  OperationArgument op_args(op_info);
  // Add one region and append a default-constructed block to it; block()
  // later asserts exactly this one-region / one-block shape.
  op_args.AddRegion()->emplace_back();
  op_args.AddAttribute("program", PointerAttribute::get(context, pointer));
  Operation *module_op = Operation::Create(std::move(op_args));
  return ModuleOp(module_op);
}

48
void ModuleOp::Destroy() {
49
  if (operation()) {
50
    operation()->Destroy();
51 52 53 54
    *this = ModuleOp(nullptr);
  }
}

55
// Verifies a ModuleOp: no operands, no results, and a "program" attribute
// that is a PointerAttribute.
void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
                      const std::vector<ir::Type> &outputs,
                      const ir::AttributeMap &attributes) {
  VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
  // A module takes no operands.
  IR_ENFORCE(inputs.empty(), "The size of inputs must be equal to 0.");

  // The "program" attribute must be present and hold a PointerAttribute.
  const auto program_it = attributes.find("program");
  const bool program_ok = program_it != attributes.end() &&
                          program_it->second.isa<PointerAttribute>();
  IR_ENFORCE(program_ok, "Type of attribute: program is not right.");

  // A module produces no results.
  IR_ENFORCE(outputs.empty(), "The size of outputs must be equal to 0.");
}

71 72
const char *GetParameterOp::attributes_name[attributes_num] = {
    "parameter_name"};
73

74 75 76 77 78 79 80 81 82
void GetParameterOp::Build(Builder &builder,
                           OperationArgument &argument,
                           const std::string &name,
                           Type type) {
  argument.attributes[attributes_name[0]] =
      ir::StrAttribute::get(builder.ir_context(), name);
  argument.output_types.emplace_back(type);
}

83
// Verifies a GetParameterOp: no operands, a StrAttribute "parameter_name",
// and exactly one result.
void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
                            const std::vector<ir::Type> &outputs,
                            const ir::AttributeMap &attributes) {
  VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
  // No operands are allowed.
  IR_ENFORCE(inputs.empty(), "The size of inputs must be equal to 0.");

  // "parameter_name" must be present and be a string attribute.
  const auto name_it = attributes.find("parameter_name");
  const bool name_ok =
      name_it != attributes.end() && name_it->second.isa<StrAttribute>();
  IR_ENFORCE(name_ok, "Type of attribute: parameter_name is not right.");

  // Exactly one result.
  IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
}

99 100
const char *SetParameterOp::attributes_name[attributes_num] = {
    "parameter_name"};
101

102 103 104 105 106 107 108 109
void SetParameterOp::Build(Builder &builder,             // NOLINT
                           OperationArgument &argument,  // NOLINT
                           OpResult parameter,
                           const std::string &name) {
  argument.AddOperand(parameter);
  argument.AddAttribute(attributes_name[0],
                        ir::StrAttribute::get(builder.ir_context(), name));
}
110
// Verifies a SetParameterOp: exactly one operand, a StrAttribute
// "parameter_name", and no results.
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
                            const std::vector<ir::Type> &outputs,
                            const ir::AttributeMap &attributes) {
  VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
  // Verify inputs type: this check is on the operands, so the message must
  // name inputs (not outputs).
  IR_ENFORCE(inputs.size() == 1, "The size of inputs must be equal to 1.");

  // Verify if attributes contain attribute name in attributes_name:
  auto iter = attributes.find("parameter_name");
  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
             "Type of attribute: parameter_name is not right.");

  // Verify outputs type:
  IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}

126 127 128 129 130 131 132 133 134 135 136 137
// Builds a CombineOp: `inputs` become the operands, and the single result is
// a VectorType whose element types are the operand types, in order.
void CombineOp::Build(Builder &builder,
                      OperationArgument &argument,
                      const std::vector<ir::OpResult> &inputs) {
  argument.inputs = inputs;
  std::vector<ir::Type> element_types;
  element_types.reserve(inputs.size());
  for (const auto &operand : inputs) {
    element_types.push_back(operand.type());
  }
  argument.output_types.emplace_back(
      ir::VectorType::get(builder.ir_context(), element_types));
}

138
// Verifies a CombineOp: exactly one result, which is a VectorType with one
// element per operand, whose i-th element type equals the i-th operand type.
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
                       const std::vector<ir::Type> &outputs,
                       const ir::AttributeMap &attributes) {
  // Exactly one result.
  IR_ENFORCE(outputs.size() == 1,
             "The size %d of outputs must be equal to 1.",
             outputs.size());

  // The result must be a VectorType...
  const ir::Type &result_type = outputs[0];
  IR_ENFORCE(result_type.isa<ir::VectorType>(),
             "The type %s of outputs[0] must be equal to VectorType.",
             result_type);
  ir::VectorType packed = result_type.dyn_cast<ir::VectorType>();

  // ...holding one element per operand...
  const size_t operand_count = inputs.size();
  IR_ENFORCE(packed.size() == operand_count,
             "The size %d of outputs[0] must be equal to size %d of inputs.",
             packed.size(),
             operand_count);

  // ...and each element type must match the corresponding operand type.
  for (size_t pos = 0; pos < operand_count; ++pos) {
    ir::Type operand_type = inputs[pos].type();
    IR_ENFORCE(packed[pos] == operand_type,
               "The type %s of outputs[0][%d] must be "
               "equal to type %s of inputs[%d].",
               packed[pos],
               pos,
               operand_type,
               pos);
  }
}

const char *SliceOp::attributes_name[attributes_num] = {"index"};
170
void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
171 172 173
                     const std::vector<ir::Type> &outputs,
                     const ir::AttributeMap &attributes) {
  // inputs.size() == 1
174 175 176
  IR_ENFORCE(inputs.size() == 1,
             "The size %d of inputs must be equal to 1.",
             inputs.size());
177 178

  // inputs[0].type == Vector<Type>
179 180 181
  IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
             "The type %s of inputs[0] must be equal to VectorType.",
             inputs[0].type());
182 183 184
  ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();

  // outputs.size() == 1
185 186 187
  IR_ENFORCE(outputs.size() == 1,
             "The size %d of outputs must be equal to 1.",
             outputs.size());
188 189

  // attributes contains index: Int32
190 191
  IR_ENFORCE(attributes.count("index") != 0,
             "The attributes must contains index.");
192
  const ir::Attribute &attr = attributes.at("index");
Z
zhangbo9674 已提交
193
  IR_ENFORCE(attr.isa<ir::Int32Attribute>(),
194
             "The attribute index must be INT32.");
Z
zhangbo9674 已提交
195
  auto index = attr.dyn_cast<ir::Int32Attribute>().data();
196 197

  // index >= 0 and < inputs[0].size()
198 199 200 201 202 203
  IR_ENFORCE(
      index >= 0, "The index %d must be greater or equal than 0.", index);
  IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
             "The index %d must be less or equal than size %d of inputs[0].",
             index,
             input_type.size());
204 205

  // inputs[index].type == outputs[0].type
206 207 208
  IR_ENFORCE(
      input_type[index] == outputs[0],
      "The type %s of inputs[%d] must be equal to type %s of outputs[0].",
209
      input_type[index],
210 211
      index,
      outputs[0]);
212 213
}

214 215
const char *ConstantOp::attributes_name[attributes_num] = {"value"};

216 217
void ConstantOp::Build(Builder &builder,
                       OperationArgument &argument,
218 219 220 221 222 223 224
                       Attribute value,
                       Type output_type) {
  argument.AddAttribute("value", value);
  argument.output_types.push_back(output_type);
}

// Verifies a ConstantOp: no operands, exactly one result, and a "value"
// attribute.
void ConstantOp::Verify(const std::vector<ir::OpResult> &inputs,
                        const std::vector<ir::Type> &outputs,
                        const ir::AttributeMap &attributes) {
  IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
  IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
  // Only the presence of "value" is checked here, so the message reports a
  // missing attribute rather than a wrong attribute type.
  IR_ENFORCE(attributes.count("value") > 0,
             "The attributes must contain value.");
}

233 234
// Returns the "value" attribute of this op's operation. The entry is looked up
// with at(); ConstantOp::Verify requires it to be present.
Attribute ConstantOp::value() {
  const AttributeMap &attr_map = operation()->attributes();
  return attr_map.at("value");
}

235
}  // namespace ir
236 237 238 239 240 241 242 243

// Explicit type-id definitions for the builtin op classes implemented above
// and for ir::ConstantLikeTrait. They are placed after namespace ir is closed,
// hence the fully qualified ir:: names.
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ModuleOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp)