operation.cc 9.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include <ostream>

17
#include "paddle/ir/core/block.h"
18
#include "paddle/ir/core/dialect.h"
19
#include "paddle/ir/core/enforce.h"
20 21
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
22
#include "paddle/ir/core/program.h"
23
#include "paddle/ir/core/region.h"
24
#include "paddle/ir/core/utils.h"
25
#include "paddle/ir/core/value_impl.h"
26 27

namespace ir {
28 29
// Builds an operation from an OperationArgument bundle: the op is created with
// one empty Region per entry in `argument.regions`, and each region then takes
// over the body of the corresponding argument region.
Operation *Operation::Create(OperationArgument &&argument) {
  const size_t region_count = argument.regions.size();
  Operation *created = Create(argument.inputs,
                              argument.attributes,
                              argument.output_types,
                              argument.info,
                              region_count);
  for (size_t i = 0; i < region_count; ++i) {
    Region &target = created->region(i);
    target.TakeBody(std::move(*argument.regions[i]));
  }
  return created;
}

41 42
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
43
// OpInlineResult, Operation, operand.
44
Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
45
                             const AttributeMap &attributes,
46 47 48
                             const std::vector<ir::Type> &output_types,
                             ir::OpInfo op_info,
                             size_t num_regions) {
49 50 51 52 53 54 55 56 57 58 59 60 61 62
  // 1. Calculate the required memory size for OpResults + Operation +
  // OpOperands.
  uint32_t num_results = output_types.size();
  uint32_t num_operands = inputs.size();
  uint32_t max_inline_result_num =
      detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
  size_t result_mem_size =
      num_results > max_inline_result_num
          ? sizeof(detail::OpOutlineResultImpl) *
                    (num_results - max_inline_result_num) +
                sizeof(detail::OpInlineResultImpl) * max_inline_result_num
          : sizeof(detail::OpInlineResultImpl) * num_results;
  size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
  size_t op_mem_size = sizeof(Operation);
63 64 65
  size_t region_mem_size = num_regions * sizeof(Region);
  size_t base_size =
      result_mem_size + op_mem_size + operand_mem_size + region_mem_size;
66 67 68 69 70 71 72 73 74 75 76 77 78 79
  // 2. Malloc memory.
  char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
  // 3.1. Construct OpResults.
  for (size_t idx = num_results; idx > 0; idx--) {
    if (idx > max_inline_result_num) {
      new (base_ptr)
          detail::OpOutlineResultImpl(output_types[idx - 1], idx - 1);
      base_ptr += sizeof(detail::OpOutlineResultImpl);
    } else {
      new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1);
      base_ptr += sizeof(detail::OpInlineResultImpl);
    }
  }
  // 3.2. Construct Operation.
80
  Operation *op = new (base_ptr)
81
      Operation(attributes, op_info, num_results, num_operands, num_regions);
82 83 84
  base_ptr += sizeof(Operation);
  // 3.3. Construct OpOperands.
  if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
85
    IR_THROW("The address of OpOperandImpl must be divisible by 8.");
86 87 88 89 90
  }
  for (size_t idx = 0; idx < num_operands; idx++) {
    new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
    base_ptr += sizeof(detail::OpOperandImpl);
  }
91 92 93 94 95 96 97 98
  // 3.4. Construct Regions
  if (num_regions > 0) {
    op->regions_ = reinterpret_cast<Region *>(base_ptr);
    for (size_t idx = 0; idx < num_regions; idx++) {
      new (base_ptr) Region(op);
      base_ptr += sizeof(Region);
    }
  }
99 100 101 102 103

  // 0. Verify
  if (op_info) {
    op_info.Verify(op);
  }
104 105 106
  return op;
}

107 108
// Call destructors for Region , OpResults, Operation, and OpOperands in
// sequence, and finally free memory.
109
void Operation::Destroy() {
110
  VLOG(6) << "Destroy Operation [" << name() << "] ...";
111
  // 1. Deconstruct Regions.
112 113 114 115 116 117
  if (num_regions_ > 0) {
    for (size_t idx = 0; idx < num_regions_; idx++) {
      regions_[idx].~Region();
    }
  }

118 119 120
  // 2. Deconstruct Result.
  for (size_t idx = 0; idx < num_results_; ++idx) {
    detail::OpResultImpl *impl = result(idx).impl();
121 122
    IR_ENFORCE(impl->use_empty(),
               name() + " operation destroyed but still has uses.");
123 124 125 126 127 128 129 130 131 132 133 134
    if (detail::OpOutlineResultImpl::classof(*impl)) {
      static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
    } else {
      static_cast<detail::OpInlineResultImpl *>(impl)->~OpInlineResultImpl();
    }
  }

  // 3. Deconstruct Operation.
  this->~Operation();

  // 4. Deconstruct OpOperand.
  for (size_t idx = 0; idx < num_operands_; idx++) {
135
    operand(idx).impl()->~OpOperandImpl();
136 137
  }
  // 5. Free memory.
138 139 140 141 142 143 144 145
  uint32_t max_inline_result_num =
      detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
  size_t result_mem_size =
      num_results_ > max_inline_result_num
          ? sizeof(detail::OpOutlineResultImpl) *
                    (num_results_ - max_inline_result_num) +
                sizeof(detail::OpInlineResultImpl) * max_inline_result_num
          : sizeof(detail::OpInlineResultImpl) * num_results_;
146 147
  void *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;

148 149
  VLOG(6) << "Destroy Operation [" << name() << "]: {ptr = " << aligned_ptr
          << ", size = " << result_mem_size << "} done.";
150
  aligned_free(aligned_ptr);
151 152
}

153
// Returns the IrContext reported by this operation's OpInfo.
IrContext *Operation::ir_context() const {
  const auto &op_info = info_;
  return op_info.ir_context();
}
154

155 156
// Returns the Dialect reported by this operation's OpInfo.
Dialect *Operation::dialect() const {
  const auto &op_info = info_;
  return op_info.dialect();
}

157
// Stores the attributes and OpInfo and records the result, operand and region
// counts. The objects themselves are placement-constructed by Create(); this
// constructor only remembers how many of each there are.
Operation::Operation(const AttributeMap &attributes,
                     ir::OpInfo op_info,
                     uint32_t num_results,
                     uint32_t num_operands,
                     uint32_t num_regions)
    : attributes_(attributes),
      info_(op_info),
      num_results_(num_results),
      num_operands_(num_operands),
      num_regions_(num_regions) {
  // All state is set in the member-initializer list above.
}
167

168
// Returns the OpResult at `index`; throws if `index` is out of range.
ir::OpResult Operation::result(uint32_t index) const {
  if (index >= num_results_) {
    IR_THROW("index exceeds OP output range.");
  }
  // Results live directly below the Operation object. Inline results for
  // indices [0, max_inline_idx] are nearest to `this`; outline results for
  // larger indices are placed below all inline results.
  const uint32_t max_inline_idx =
      detail::OpResultImpl::GetMaxInlineResultIndex();
  const char *op_addr = reinterpret_cast<const char *>(this);
  if (index <= max_inline_idx) {
    const char *slot =
        op_addr - (index + 1) * sizeof(detail::OpInlineResultImpl);
    return ir::OpResult(
        reinterpret_cast<const detail::OpInlineResultImpl *>(slot));
  }
  const size_t inline_span =
      (max_inline_idx + 1) * sizeof(detail::OpInlineResultImpl);
  const size_t outline_span =
      (index - max_inline_idx) * sizeof(detail::OpOutlineResultImpl);
  const char *slot = op_addr - inline_span - outline_span;
  return ir::OpResult(
      reinterpret_cast<const detail::OpOutlineResultImpl *>(slot));
}

189
// Returns the OpOperand at `index`; throws if `index` is out of range.
OpOperand Operation::operand(uint32_t index) const {
  if (index >= num_operands_) {
    IR_THROW("index exceeds OP input range.");
  }
  // Operands are stored contiguously right after the Operation object.
  const char *operands_begin =
      reinterpret_cast<const char *>(this) + sizeof(Operation);
  const size_t byte_offset =
      static_cast<size_t>(index) * sizeof(detail::OpOperandImpl);
  return OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(
      operands_begin + byte_offset));
}

198 199
// Returns the Value feeding operand `index`, or a default-constructed Value
// when the operand handle is empty.
Value Operation::operand_source(uint32_t index) const {
  OpOperand val = operand(index);
  if (val) {
    return val.source();
  }
  return Value();
}

203 204 205
// Returns the op name from its OpInfo, or an empty string if none is set.
std::string Operation::name() const {
  if (auto op_name = info_.name()) {
    return op_name;
  }
  return "";
}

208 209 210 211 212
// Returns the attribute stored under `key`; raises an IR error when the
// operation has no such attribute.
Attribute Operation::attribute(const std::string &key) const {
  IR_ENFORCE(HasAttribute(key), "operation(%s): no attribute %s", name(), key);
  const auto &found = attributes_.at(key);
  return found;
}

213
// Returns the region containing this op's parent block, or nullptr when the
// op is not inserted into a block.
Region *Operation::GetParentRegion() {
  if (parent_ == nullptr) {
    return nullptr;
  }
  return parent_->GetParent();
}

// Returns the operation owning this op's parent block, or nullptr when the op
// is not inserted into a block.
Operation *Operation::GetParentOp() const {
  if (parent_ == nullptr) {
    return nullptr;
  }
  return parent_->GetParentOp();
}

221 222
// Walks up the parent-op chain to the outermost operation. If that operation
// is a ModuleOp, returns its Program; otherwise returns nullptr.
const Program *Operation::GetParentProgram() const {
  Operation *root = const_cast<Operation *>(this);
  for (Operation *up = root->GetParentOp(); up != nullptr;
       up = root->GetParentOp()) {
    root = up;
  }
  if (ModuleOp module_op = root->dyn_cast<ModuleOp>()) {
    return module_op.program();
  }
  return nullptr;
}

230
// Returns the region at `index` (bounds checked by assert in debug builds).
Region &Operation::region(unsigned index) {
  assert(index < num_regions_ && "invalid region index");
  Region *target = regions_ + index;
  return *target;
}

235 236 237 238 239
// Const overload: returns the region at `index` (assert-checked in debug).
const Region &Operation::region(unsigned index) const {
  assert(index < num_regions_ && "invalid region index");
  const Region *target = regions_ + index;
  return *target;
}

240 241 242 243 244
// Records the block this op belongs to and the iterator marking its position.
void Operation::SetParent(Block *parent, const Block::iterator &position) {
  position_ = position;
  parent_ = parent;
}

245 246 247 248 249 250 251 252
void Operation::ReplaceAllUsesWith(const std::vector<Value> &values) {
  IR_ENFORCE(num_results_ == values.size(),
             "the num of result should be the same.");
  for (uint32_t i = 0; i < num_results_; ++i) {
    result(i).ReplaceAllUsesWith(values[i]);
  }
}

253 254 255 256 257 258 259 260
void Operation::ReplaceAllUsesWith(const std::vector<OpResult> &op_results) {
  IR_ENFORCE(num_results_ == op_results.size(),
             "the num of result should be the same.");
  for (uint32_t i = 0; i < num_results_; ++i) {
    result(i).ReplaceAllUsesWith(op_results[i]);
  }
}

261 262 263 264 265 266
// Runs the OpInfo verifier on this op; does nothing when info_ is empty.
void Operation::Verify() {
  if (!info_) {
    return;
  }
  info_.Verify(this);
}

267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
std::vector<OpOperand> Operation::operands() const {
  std::vector<OpOperand> res;
  for (uint32_t i = 0; i < num_operands(); ++i) {
    res.push_back(operand(i));
  }
  return res;
}

std::vector<OpResult> Operation::results() const {
  std::vector<OpResult> res;
  for (uint32_t i = 0; i < num_results(); ++i) {
    res.push_back(result(i));
  }
  return res;
}

283
}  // namespace ir