operation.cc 9.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/core/operation.h"

#include <cstdint>
#include <sstream>
#include <string>
#include <utility>

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/utils.h"

namespace ir {
23 24
Operation *Operation::create(OperationArgument &&argument) {
  Operation *op = create(argument.inputs,
25
                         argument.attributes,
26 27 28 29 30 31 32 33
                         argument.output_types,
                         argument.info,
                         argument.regions.size());

  for (size_t index = 0; index < argument.regions.size(); ++index) {
    op->GetRegion(index).TakeBody(std::move(*argument.regions[index]));
  }
  return op;
34 35
}

36 37 38 39
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
40
                             const AttributeMap &attributes,
41 42 43
                             const std::vector<ir::Type> &output_types,
                             ir::OpInfo op_info,
                             size_t num_regions) {
44 45
  // 0. Verify
  if (op_info) {
46
    op_info.verify(inputs, output_types, attributes);
47
  }
48 49 50 51 52 53 54 55 56 57 58 59 60 61
  // 1. Calculate the required memory size for OpResults + Operation +
  // OpOperands.
  uint32_t num_results = output_types.size();
  uint32_t num_operands = inputs.size();
  uint32_t max_inline_result_num =
      detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
  size_t result_mem_size =
      num_results > max_inline_result_num
          ? sizeof(detail::OpOutlineResultImpl) *
                    (num_results - max_inline_result_num) +
                sizeof(detail::OpInlineResultImpl) * max_inline_result_num
          : sizeof(detail::OpInlineResultImpl) * num_results;
  size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
  size_t op_mem_size = sizeof(Operation);
62 63 64
  size_t region_mem_size = num_regions * sizeof(Region);
  size_t base_size =
      result_mem_size + op_mem_size + operand_mem_size + region_mem_size;
65 66 67 68 69 70 71 72 73 74 75 76 77 78
  // 2. Malloc memory.
  char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
  // 3.1. Construct OpResults.
  for (size_t idx = num_results; idx > 0; idx--) {
    if (idx > max_inline_result_num) {
      new (base_ptr)
          detail::OpOutlineResultImpl(output_types[idx - 1], idx - 1);
      base_ptr += sizeof(detail::OpOutlineResultImpl);
    } else {
      new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1);
      base_ptr += sizeof(detail::OpInlineResultImpl);
    }
  }
  // 3.2. Construct Operation.
79
  Operation *op = new (base_ptr)
80
      Operation(attributes, op_info, num_results, num_operands, num_regions);
81 82 83 84 85 86 87 88 89
  base_ptr += sizeof(Operation);
  // 3.3. Construct OpOperands.
  if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
    throw("The address of OpOperandImpl must be divisible by 8.");
  }
  for (size_t idx = 0; idx < num_operands; idx++) {
    new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
    base_ptr += sizeof(detail::OpOperandImpl);
  }
90 91 92 93 94 95 96 97
  // 3.4. Construct Regions
  if (num_regions > 0) {
    op->regions_ = reinterpret_cast<Region *>(base_ptr);
    for (size_t idx = 0; idx < num_regions; idx++) {
      new (base_ptr) Region(op);
      base_ptr += sizeof(Region);
    }
  }
98 99 100 101 102 103
  return op;
}

// Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory.
void Operation::destroy() {
104 105 106 107 108 109 110
  // Deconstruct Regions.
  if (num_regions_ > 0) {
    for (size_t idx = 0; idx < num_regions_; idx++) {
      regions_[idx].~Region();
    }
  }

111 112 113 114 115 116 117 118 119 120 121 122 123
  // 1. Get aligned_ptr by result_num.
  uint32_t max_inline_result_num =
      detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
  size_t result_mem_size =
      num_results_ > max_inline_result_num
          ? sizeof(detail::OpOutlineResultImpl) *
                    (num_results_ - max_inline_result_num) +
                sizeof(detail::OpInlineResultImpl) * max_inline_result_num
          : sizeof(detail::OpInlineResultImpl) * num_results_;
  char *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
  // 2.1. Deconstruct OpResult.
  char *base_ptr = aligned_ptr;
  for (size_t idx = num_results_; idx > 0; idx--) {
124 125 126
    // release the uses of this result
    detail::OpOperandImpl *first_use =
        reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
Z
zhangbo9674 已提交
127 128 129 130
    while (first_use != nullptr) {
      first_use->remove_from_ud_chain();
      first_use =
          reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
131
    }
132
    // destory the result
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
    if (idx > max_inline_result_num) {
      reinterpret_cast<detail::OpOutlineResultImpl *>(base_ptr)
          ->~OpOutlineResultImpl();
      base_ptr += sizeof(detail::OpOutlineResultImpl);
    } else {
      reinterpret_cast<detail::OpInlineResultImpl *>(base_ptr)
          ->~OpInlineResultImpl();
      base_ptr += sizeof(detail::OpInlineResultImpl);
    }
  }
  // 2.2. Deconstruct Operation.
  if (reinterpret_cast<uintptr_t>(base_ptr) !=
      reinterpret_cast<uintptr_t>(this)) {
    throw("Operation address error");
  }
  reinterpret_cast<Operation *>(base_ptr)->~Operation();
  base_ptr += sizeof(Operation);
C
co63oc 已提交
150
  // 2.3. Deconstruct OpOperand.
151 152 153 154 155 156 157 158 159 160 161
  for (size_t idx = 0; idx < num_operands_; idx++) {
    reinterpret_cast<detail::OpOperandImpl *>(base_ptr)->~OpOperandImpl();
    base_ptr += sizeof(detail::OpOperandImpl);
  }
  // 3. Free memory.
  VLOG(4) << "Destroy an Operation: {ptr = "
          << reinterpret_cast<void *>(aligned_ptr)
          << ", size = " << result_mem_size << "}";
  aligned_free(reinterpret_cast<void *>(aligned_ptr));
}

162 163
IrContext *Operation::ir_context() const { return op_info_.ir_context(); }

164
Operation::Operation(const AttributeMap &attributes,
165 166
                     ir::OpInfo op_info,
                     uint32_t num_results,
167
                     uint32_t num_operands,
168
                     uint32_t num_regions)
169
    : attributes_(attributes),
170 171 172 173
      op_info_(op_info),
      num_results_(num_results),
      num_operands_(num_operands),
      num_regions_(num_regions) {}
174

175
ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
176 177 178 179
  if (index >= num_results_) {
    throw("index exceeds OP output range.");
  }
  uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
180 181 182 183 184 185 186
  const char *ptr =
      (index > max_inline_idx)
          ? reinterpret_cast<const char *>(this) -
                (max_inline_idx + 1) * sizeof(detail::OpInlineResultImpl) -
                (index - max_inline_idx) * sizeof(detail::OpOutlineResultImpl)
          : reinterpret_cast<const char *>(this) -
                (index + 1) * sizeof(detail::OpInlineResultImpl);
187
  if (index > max_inline_idx) {
188 189
    return ir::OpResult(
        reinterpret_cast<const detail::OpOutlineResultImpl *>(ptr));
190
  } else {
191 192
    return ir::OpResult(
        reinterpret_cast<const detail::OpInlineResultImpl *>(ptr));
193 194 195
  }
}

196
ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
197 198 199
  if (index >= num_operands_) {
    throw("index exceeds OP input range.");
  }
200 201 202
  const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
                    (index) * sizeof(detail::OpOperandImpl);
  return ir::OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
203 204
}

205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
std::string Operation::print() {
  std::stringstream result;
  result << "{ " << num_results_ << " outputs, " << num_operands_
         << " inputs } : ";
  result << "[ ";
  for (size_t idx = num_results_; idx > 0; idx--) {
    result << GetResultByIndex(idx - 1).impl_ << ", ";
  }
  result << "] = ";
  result << this << "( ";
  for (size_t idx = 0; idx < num_operands_; idx++) {
    result << reinterpret_cast<void *>(reinterpret_cast<char *>(this) +
                                       sizeof(Operation) +
                                       idx * sizeof(detail::OpOperandImpl))
           << ", ";
  }
  result << ")";
  return result.str();
}

225
std::string Operation::op_name() const { return op_info_.name(); }
226

227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
Region *Operation::GetParentRegion() const {
  return parent_ ? parent_->GetParentRegion() : nullptr;
}

// Returns the operation owning this operation's parent, or nullptr when
// parent_ is unset.
Operation *Operation::GetParentOp() const {
  if (!parent_) {
    return nullptr;
  }
  return parent_->GetParentOp();
}

// Climbs parent operations up to the outermost one. If that root operation is
// a ModuleOp, returns its Program; otherwise returns nullptr.
Program *Operation::GetParentProgram() {
  Operation *root = this;
  for (Operation *up = root->GetParentOp(); up != nullptr;
       up = root->GetParentOp()) {
    root = up;
  }
  ModuleOp module_op = root->dyn_cast<ModuleOp>();
  if (module_op) {
    return module_op.program();
  }
  return nullptr;
}

244 245 246 247 248
Region &Operation::GetRegion(unsigned index) {
  assert(index < num_regions_ && "invalid region index");
  return regions_[index];
}

249
}  // namespace ir