value.cc 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/ir/core/value.h"
16
#include "paddle/ir/core/enforce.h"
17
#include "paddle/ir/core/operation.h"
18
#include "paddle/ir/core/value_impl.h"
19

20 21 22 23 24 25 26 27 28 29 30
// Null-check helpers for the handle classes below. Each expands to an
// IR_ENFORCE on the handle's `impl_` pointer whose message names the called
// member function and the class, e.g.
//   "impl_ pointer is null when call func:type , in class: Value."
#define CHECK_NULL_IMPL(class_name, func_name)                  \
  IR_ENFORCE(impl_,                                             \
             "impl_ pointer is null when call func:" #func_name \
             " , in class: " #class_name ".")

// Null check for OpOperand members. The macro keeps its historical spelling
// because call sites use it, but the reported class name is now correct
// (it previously read "OpOpernad").
#define CHECK_OPOPEREND_NULL_IMPL(func_name) \
  CHECK_NULL_IMPL(OpOperand, func_name)

// Null check for Value members.
#define CHECK_VALUE_NULL_IMPL(func_name) CHECK_NULL_IMPL(Value, func_name)

// Null check for OpResult members.
#define CHECK_OPRESULT_NULL_IMPL(func_name) CHECK_NULL_IMPL(OpResult, func_name)
31
namespace ir {
32

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
// OpOperand
OpOperand::OpOperand(const detail::OpOperandImpl *impl)
    : impl_(const_cast<detail::OpOperandImpl *>(impl)) {}

OpOperand &OpOperand::operator=(const OpOperand &rhs) {
  if (this == &rhs) return *this;
  impl_ = rhs.impl_;
  return *this;
}

OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
  if (this->impl_ == impl) return *this;
  impl_ = const_cast<detail::OpOperandImpl *>(impl);
  return *this;
}
K
kangguangli 已提交
48
// A handle is "true" only when it wraps an impl that currently has a source.
OpOperand::operator bool() const {
  if (impl_ == nullptr) return false;
  return static_cast<bool>(impl_->source());
}
49

50 51 52 53
// Handle to the next operand in the same source value's use chain.
OpOperand OpOperand::next_use() const {
  CHECK_OPOPEREND_NULL_IMPL(next_use);
  return OpOperand(impl_->next_use());
}
54

55 56 57 58
// The value this operand reads (null when the operand is detached).
Value OpOperand::source() const {
  CHECK_OPOPEREND_NULL_IMPL(source);
  Value value = impl_->source();
  return value;
}
59

60 61
// An operand's type is the type of the value it reads.
Type OpOperand::type() const {
  Value value = source();
  return value.type();
}

62 63 64 65
// Make this operand read `value`; the impl handles relinking the use chains.
void OpOperand::set_source(Value value) {
  CHECK_OPOPEREND_NULL_IMPL(set_source);
  detail::OpOperandImpl *operand = impl_;
  operand->set_source(value);
}
66

67 68 69 70
// The operation this operand belongs to.
Operation *OpOperand::owner() const {
  CHECK_OPOPEREND_NULL_IMPL(owner);
  Operation *op = impl_->owner();
  return op;
}
71

72 73 74
// Unlink the wrapped operand from its source value's use chain.
void OpOperand::RemoveFromUdChain() {
  CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain);
  impl_->RemoveFromUdChain();
}
76

77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
// Value
Value::Value(const detail::ValueImpl *impl)
    : impl_(const_cast<detail::ValueImpl *>(impl)) {}

bool Value::operator==(const Value &other) const {
  return impl_ == other.impl_;
}

bool Value::operator!=(const Value &other) const {
  return impl_ != other.impl_;
}

bool Value::operator!() const { return impl_ == nullptr; }

Value::operator bool() const { return impl_; }

93 94 95 96
// Type of the wrapped value.
ir::Type Value::type() const {
  CHECK_VALUE_NULL_IMPL(type);
  detail::ValueImpl *value_impl = impl_;
  return value_impl->type();
}
97

98 99 100 101
// Overwrite the type stored in the wrapped value.
void Value::set_type(ir::Type type) {
  CHECK_VALUE_NULL_IMPL(set_type);
  detail::ValueImpl *value_impl = impl_;
  value_impl->set_type(type);
}
102 103 104 105 106 107

// Operation that produces this value when it is an OpResult; nullptr for any
// other kind of value (or a null value).
Operation *Value::GetDefiningOp() const {
  auto result = dyn_cast<OpResult>();
  return result ? result.owner() : nullptr;
}

108 109 110 111
// Debug string describing this value's use chain; formatting is done by the
// underlying ValueImpl.
std::string Value::PrintUdChain() {
  CHECK_VALUE_NULL_IMPL(PrintUdChain);
  auto *value_impl = impl();
  return value_impl->PrintUdChain();
}
112

113 114 115
// Iterator positioned at the head of the use chain, i.e. the most recently
// linked use.
Value::UseIterator Value::use_begin() const { return first_use(); }
116

117
// Past-the-end iterator: a default-constructed UseIterator.
Value::UseIterator Value::use_end() const {
  return UseIterator();
}
118

119 120 121 122
// Handle to the head of this value's use chain.
OpOperand Value::first_use() const {
  CHECK_VALUE_NULL_IMPL(first_use);
  return OpOperand(impl_->first_use());
}
123

124 125
bool Value::use_empty() const { return !first_use(); }

126 127 128 129
// Whether exactly one operand reads this value, as reported by the impl.
bool Value::HasOneUse() const {
  CHECK_VALUE_NULL_IMPL(HasOneUse);
  const bool single_use = impl_->HasOneUse();
  return single_use;
}
130

131 132 133
// Redirect every use of this value for which `should_replace` returns true so
// that it reads `new_value` instead. Uses rejected by the predicate are left
// untouched.
void Value::ReplaceUsesWithIf(
    Value new_value,
    const std::function<bool(OpOperand)> &should_replace) const {
  for (auto it = use_begin(); it != use_end();) {
    // Step past the current use before touching it: set_source() unlinks the
    // operand from this value's use chain, which would invalidate `it`.
    // Advancing unconditionally also guarantees progress when the predicate
    // rejects a use (previously the loop never advanced in that case and
    // spun forever).
    auto current = it++;
    if (should_replace(*current)) {
      current->set_source(new_value);
    }
  }
}

141
// Redirect every use of this value so that it reads `new_value` instead.
void Value::ReplaceAllUsesWith(Value new_value) const {
  auto use = use_begin();
  while (use != use_end()) {
    // Advance first: set_source() unlinks the operand from this chain.
    auto redirected = use++;
    redirected->set_source(new_value);
  }
}

147 148
// OpResult
// A null value is never an OpResult; otherwise dispatch on the impl's kind.
bool OpResult::classof(Value value) {
  if (!value) return false;
  return ir::isa<detail::OpResultImpl>(value.impl());
}

152 153 154 155
// Operation that produced this result.
Operation *OpResult::owner() const {
  CHECK_OPRESULT_NULL_IMPL(owner);
  detail::OpResultImpl *result = impl();
  return result->owner();
}
156

157 158 159 160
uint32_t OpResult::GetResultIndex() const {
  CHECK_OPRESULT_NULL_IMPL(GetResultIndex);
  return impl()->GetResultIndex();
}
161 162 163 164 165

// Reinterpret the stored ValueImpl pointer as an OpResultImpl. No null check
// and no kind check happen here.
// NOTE(review): this presumably relies on an OpResult handle only ever wrapping
// an OpResultImpl (see OpResult::classof). If OpResultImpl is a plain single
// base-derived class of ValueImpl, static_cast would express this more safely
// — confirm against value_impl.h.
detail::OpResultImpl *OpResult::impl() const {
  return reinterpret_cast<detail::OpResultImpl *>(impl_);
}

X
xiaoguoguo626807 已提交
166 167 168 169 170 171 172 173 174
bool OpResult::operator==(const OpResult &other) const {
  return impl_ == other.impl_;
}

detail::ValueImpl *OpResult::value_impl() const {
  IR_ENFORCE(impl_, "Can't use value_impl() interface while value is null.");
  return impl_;
}

175 176 177 178 179 180 181 182 183 184 185 186
uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
  uint32_t max_inline_index =
      ir::detail::OpResultImpl::GetMaxInlineResultIndex();
  return index <= max_inline_index ? index : max_inline_index;
}

// details
namespace detail {
// Operation this operand belongs to (as given to the constructor).
ir::Operation *OpOperandImpl::owner() const { return this->owner_; }

// Next operand in the source value's use chain; nullptr at the tail or when
// the operand is detached.
ir::detail::OpOperandImpl *OpOperandImpl::next_use() { return this->next_use_; }

187 188
// Value currently read by this operand; reset to null by RemoveFromUdChain().
ir::Value OpOperandImpl::source() const {
  return this->source_;
}

189 190 191 192 193 194 195 196
// Rebind this operand to `source`. The operand is always unlinked from its
// current value's use chain first. A non-null `source` is then recorded and
// this operand is pushed onto the head of that value's chain; a null `source`
// leaves the operand detached.
void OpOperandImpl::set_source(Value source) {
  RemoveFromUdChain();
  if (source) {
    source_ = source;
    InsertToUdChain();
  }
}
197

198 199
// Create an operand of `owner` reading `source`. With a non-null source, the
// operand is linked into that value's use chain immediately; with a null
// source it starts out detached.
OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
    : source_(source), owner_(owner) {
  if (source) InsertToUdChain();
}

void OpOperandImpl::InsertToUdChain() {
  prev_use_addr_ = source_.impl()->first_use_addr();
  next_use_ = source_.impl()->first_use();
209 210 211
  if (next_use_) {
    next_use_->prev_use_addr_ = &next_use_;
  }
212
  source_.impl()->set_first_use(this);
213 214
}

215
// Unlink this operand from source_'s use chain and detach it completely
// (next_use_, prev_use_addr_ and source_ are all cleared). Does nothing when
// the operand has no source or is not currently linked.
void OpOperandImpl::RemoveFromUdChain() {
  if (!source_ || !prev_use_addr_) return;
  auto *value = source_.impl();
  if (prev_use_addr_ == value->first_use_addr()) {
    // We are the head of the chain. ValueImpl packs index information into
    // the low three bits of first_use_offseted_by_index_, so the head must be
    // rewritten through set_first_use() rather than a raw store.
    value->set_first_use(next_use_);
  } else {
    *prev_use_addr_ = next_use_;
  }
  if (next_use_) next_use_->prev_use_addr_ = prev_use_addr_;
  next_use_ = nullptr;
  prev_use_addr_ = nullptr;
  source_ = nullptr;
}

234
// A dying operand must not stay reachable from its source's use chain.
OpOperandImpl::~OpOperandImpl() {
  RemoveFromUdChain();
}
235 236 237 238 239 240 241 242 243

// Return this value's index. The low three bits of first_use_offseted_by_index_
// carry a small tag: tags 0-5 are the index itself. Any larger tag (6 or 7)
// means the index did not fit, and it is read back through
// OpOutlineResultImpl::GetResultIndex() instead.
// NOTE(review): `6` is a magic number that presumably equals
// OpResultImpl::GetMaxInlineResultIndex() + 1 — confirm and give it a name.
// NOTE(review): the reinterpret_cast assumes every value reaching the
// fallback is an OpOutlineResultImpl — confirm no other ValueImpl subclass
// uses tag 6 or 7.
uint32_t ValueImpl::index() const {
  uint32_t index =
      reinterpret_cast<uintptr_t>(first_use_offseted_by_index_) & 0x07;
  if (index < 6) return index;
  return reinterpret_cast<OpOutlineResultImpl *>(const_cast<ValueImpl *>(this))
      ->GetResultIndex();
}

244
std::string ValueImpl::PrintUdChain() {
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
  std::stringstream result;
  result << "Value[" << this << "] -> ";
  OpOperandImpl *tmp = first_use();
  if (tmp) {
    result << "OpOperand[" << reinterpret_cast<void *>(tmp) << "] -> ";
    while (tmp->next_use() != nullptr) {
      result << "OpOperand[" << reinterpret_cast<void *>(tmp->next_use())
             << "] -> ";
      tmp = tmp->next_use();
    }
  }
  result << "nullptr";
  return result.str();
}

// Result index, dispatched to the outline or inline representation.
uint32_t OpResultImpl::GetResultIndex() const {
  const auto *outline = ir::dyn_cast<OpOutlineResultImpl>(this);
  if (outline == nullptr) {
    return ir::dyn_cast<OpInlineResultImpl>(this)->GetResultIndex();
  }
  return outline->GetResultIndex();
}

267 268 269 270 271
// A result must have no remaining uses when it is destroyed; otherwise the
// OpOperandImpls still reading it would keep a source_ that refers to freed
// memory.
// NOTE: the assertion message must be a string literal. The previous form,
// `use_empty() && owner()->name() + "..."`, built a concatenated string that
// is not contextually convertible to bool, so it failed to compile whenever
// assert() was active (NDEBUG undefined). The op name is therefore no longer
// part of the message.
OpResultImpl::~OpResultImpl() {
  assert(use_empty() && "operation destroyed but still has uses.");
}

272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
// Recover the owning Operation purely from this result's address. The pointer
// arithmetic encodes the layout this code assumes: results are stored in
// memory before their Operation. Inline result i sits (i + 1)
// OpInlineResultImpl slots before the op. An outline result sits
// (outline_index_ - GetMaxInlineResultIndex()) OpOutlineResultImpl slots
// before the last inline result.
// NOTE(review): the C-style cast to `const OpOutlineResultImpl *` below
// assumes every non-inline result is an outline result — confirm.
ir::Operation *OpResultImpl::owner() const {
  // Inline result: step forward past results (index .. 0) to reach the op.
  if (const auto *result = ir::dyn_cast<OpInlineResultImpl>(this)) {
    result += result->GetResultIndex() + 1;
    return reinterpret_cast<Operation *>(
        const_cast<OpInlineResultImpl *>(result));
  }
  // Outline result: step forward by (outline_index_ - max inline index)
  // outline slots to reach the address of the last inline result.
  const OpOutlineResultImpl *outline_result =
      (const OpOutlineResultImpl *)(this);
  outline_result +=
      (outline_result->outline_index_ - GetMaxInlineResultIndex());
  // The last inline result (index GetMaxInlineResultIndex()) is
  // GetMaxInlineResultIndex() + 1 inline slots before the op.
  const auto *inline_result =
      reinterpret_cast<const OpInlineResultImpl *>(outline_result);
  inline_result += (GetMaxInlineResultIndex() + 1);
  return reinterpret_cast<Operation *>(
      const_cast<OpInlineResultImpl *>(inline_result));
}
}  // namespace detail
}  // namespace ir