未验证 提交 fd97d7d1 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Value system && Operation (#51992)

* add Value OpResult OpOperand class

* add Value OpResult OpOperand class

* fix bug

* fix bug

* add utils

* refine code

* add ptr offset and reset method

* add value impl

* fix bug

* refine comment of ValueImpl

* refine code of OpResult

* refine code of Value

* add some comment

* fix cpu compile bug

* refine code

* add op

* add method for op & test value

* refine unittest

* refine code by comment

* refine code

* refine code

* refine code

* refine code
上级 8cbeefea
......@@ -16,6 +16,7 @@
#include "paddle/ir/attribute.h"
#include "paddle/ir/builtin_attribute_storage.h"
#include "paddle/ir/utils.h"
namespace ir {
///
......@@ -82,15 +83,11 @@ class DictionaryAttribute : public ir::Attribute {
} // namespace ir
namespace std {
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
template <>
struct hash<ir::NamedAttribute> {
std::size_t operator()(const ir::NamedAttribute &obj) const {
return hash_combine(std::hash<ir::Attribute>()(obj.name_),
std::hash<ir::Attribute>()(obj.value_));
return ir::hash_combine(std::hash<ir::Attribute>()(obj.name_),
std::hash<ir::Attribute>()(obj.value_));
}
};
} // namespace std
......@@ -14,6 +14,7 @@
#include "paddle/ir/builtin_attribute_storage.h"
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/utils.h"
namespace ir {
......@@ -32,7 +33,7 @@ DictionaryAttributeStorage::DictionaryAttributeStorage(const ParamKey &key) {
std::size_t DictionaryAttributeStorage::HashValue(const ParamKey &key) {
std::size_t hash_value = key.size();
for (auto iter = key.begin(); iter != key.end(); ++iter) {
hash_value = hash_combine(
hash_value = ir::hash_combine(
hash_value,
std::hash<NamedAttribute>()(NamedAttribute(iter->first, iter->second)));
}
......
......@@ -83,10 +83,6 @@ struct DictionaryAttributeStorage : public AttributeStorage {
uint32_t size() const { return size_; }
private:
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
NamedAttribute *data_;
uint32_t size_;
};
......
......@@ -17,6 +17,7 @@
#include <type_traits>
#include "paddle/ir/type.h"
#include "paddle/ir/utils.h"
namespace std {
///
......@@ -109,20 +110,22 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
std::size_t hash_value = 0;
// hash dtype
hash_value =
hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value = hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
// hash layout
hash_value =
hash_combine(hash_value,
std::hash<std::underlying_type<DataLayout>::type>()(
static_cast<std::underlying_type<DataLayout>::type>(
std::get<2>(key))));
ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
// hash layout
hash_value = ir::hash_combine(
hash_value,
std::hash<std::underlying_type<DataLayout>::type>()(
static_cast<std::underlying_type<DataLayout>::type>(
std::get<2>(key))));
// hash lod
hash_value = hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
hash_value =
ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
// hash offset
hash_value =
hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
return hash_value;
}
......@@ -146,11 +149,6 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
DataLayout layout_;
LoD lod_;
size_t offset_;
private:
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/operation.h"
namespace ir {
class OpBase {
public:
Operation *operation() { return operation_; }
explicit operator bool() { return operation() != nullptr; }
operator Operation *() const { return operation_; }
Operation *operator->() const { return operation_; }
protected:
explicit OpBase(Operation *operation) : operation_(operation) {}
private:
Operation *operation_;
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/operation.h"
#include "paddle/ir/utils.h"
namespace ir {
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
ir::DictionaryAttribute attribute) {
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
uint32_t num_results = output_types.size();
uint32_t num_operands = inputs.size();
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
size_t result_mem_size =
num_results > max_inline_result_num
? sizeof(detail::OpOutlineResultImpl) *
(num_results - max_inline_result_num) +
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
: sizeof(detail::OpInlineResultImpl) * num_results;
size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
size_t op_mem_size = sizeof(Operation);
size_t base_size = result_mem_size + op_mem_size + operand_mem_size;
// 2. Malloc memory.
char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
// 3.1. Construct OpResults.
for (size_t idx = num_results; idx > 0; idx--) {
if (idx > max_inline_result_num) {
new (base_ptr)
detail::OpOutlineResultImpl(output_types[idx - 1], idx - 1);
base_ptr += sizeof(detail::OpOutlineResultImpl);
} else {
new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1);
base_ptr += sizeof(detail::OpInlineResultImpl);
}
}
// 3.2. Construct Operation.
Operation *op =
new (base_ptr) Operation(num_results, num_operands, attribute);
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
throw("The address of OpOperandImpl must be divisible by 8.");
}
for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
base_ptr += sizeof(detail::OpOperandImpl);
}
VLOG(4) << "Construct an Operation: " << op->print();
return op;
}
// Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory.
void Operation::destroy() {
// 1. Get aligned_ptr by result_num.
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
size_t result_mem_size =
num_results_ > max_inline_result_num
? sizeof(detail::OpOutlineResultImpl) *
(num_results_ - max_inline_result_num) +
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
: sizeof(detail::OpInlineResultImpl) * num_results_;
char *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
// 2.1. Deconstruct OpResult.
char *base_ptr = aligned_ptr;
for (size_t idx = num_results_; idx > 0; idx--) {
if (!reinterpret_cast<detail::OpResultImpl *>(base_ptr)->use_empty()) {
throw("Cannot destroy a value that still has uses!");
}
if (idx > max_inline_result_num) {
reinterpret_cast<detail::OpOutlineResultImpl *>(base_ptr)
->~OpOutlineResultImpl();
base_ptr += sizeof(detail::OpOutlineResultImpl);
} else {
reinterpret_cast<detail::OpInlineResultImpl *>(base_ptr)
->~OpInlineResultImpl();
base_ptr += sizeof(detail::OpInlineResultImpl);
}
}
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
throw("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
// 2.3. Deconstruct OpOpOerand.
for (size_t idx = 0; idx < num_operands_; idx++) {
reinterpret_cast<detail::OpOperandImpl *>(base_ptr)->~OpOperandImpl();
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3. Free memory.
VLOG(4) << "Destroy an Operation: {ptr = "
<< reinterpret_cast<void *>(aligned_ptr)
<< ", size = " << result_mem_size << "}";
aligned_free(reinterpret_cast<void *>(aligned_ptr));
}
Operation::Operation(uint32_t num_results,
uint32_t num_operands,
ir::DictionaryAttribute attribute) {
if (!attribute) {
throw("unexpected null attribute dictionary");
}
num_results_ = num_results;
num_operands_ = num_operands;
attribute_ = attribute;
}
ir::OpResult Operation::GetResultByIndex(uint32_t index) {
if (index >= num_results_) {
throw("index exceeds OP output range.");
}
uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
char *ptr = nullptr;
if (index > max_inline_idx) {
ptr = reinterpret_cast<char *>(this) -
(max_inline_idx + 1) * sizeof(detail::OpInlineResultImpl) -
(index - max_inline_idx) * sizeof(detail::OpOutlineResultImpl);
} else {
ptr = reinterpret_cast<char *>(this) -
(index + 1) * sizeof(detail::OpInlineResultImpl);
}
if (index > max_inline_idx) {
detail::OpOutlineResultImpl *result_impl_ptr =
reinterpret_cast<detail::OpOutlineResultImpl *>(ptr);
return ir::OpResult(result_impl_ptr);
} else {
detail::OpInlineResultImpl *result_impl_ptr =
reinterpret_cast<detail::OpInlineResultImpl *>(ptr);
return ir::OpResult(result_impl_ptr);
}
}
std::string Operation::print() {
std::stringstream result;
result << "{ " << num_results_ << " outputs, " << num_operands_
<< " inputs } : ";
result << "[ ";
for (size_t idx = num_results_; idx > 0; idx--) {
result << GetResultByIndex(idx - 1).impl_ << ", ";
}
result << "] = ";
result << this << "( ";
for (size_t idx = 0; idx < num_operands_; idx++) {
result << reinterpret_cast<void *>(reinterpret_cast<char *>(this) +
sizeof(Operation) +
idx * sizeof(detail::OpOperandImpl))
<< ", ";
}
result << ")";
return result.str();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/type.h"
#include "paddle/ir/value_impl.h"
namespace ir {
class alignas(8) Operation final {
public:
///
/// \brief Malloc memory and construct objects in the following order:
/// OpResultImpls|Operation|OpOperandImpls.
///
static Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
ir::DictionaryAttribute attribute);
void destroy();
ir::OpResult GetResultByIndex(uint32_t index);
std::string print();
ir::DictionaryAttribute attribute() { return attribute_; }
uint32_t num_results() { return num_results_; }
uint32_t num_operands() { return num_operands_; }
private:
Operation(uint32_t num_results,
uint32_t num_operands,
ir::DictionaryAttribute attribute);
ir::DictionaryAttribute attribute_;
uint32_t num_results_ = 0;
uint32_t num_operands_ = 0;
};
} // namespace ir
cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest)
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS new_ir gtest)
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS new_ir gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/attribute.h"
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/ir_context.h"
#include "paddle/ir/operation.h"
// This unittest is used to test the construction interfaces of value class and
// operation. The constructed test scenario is: a = OP1(); b = OP2(); c = OP3(a,
// b); d, e, f, g, h, i, j = OP4(a, c);
ir::DictionaryAttribute CreateAttribute(std::string attribute_name,
std::string attribute) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::StrAttribute attr_name = ir::StrAttribute::get(ctx, attribute_name);
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attribute);
std::map<ir::StrAttribute, ir::Attribute> named_attr;
named_attr.insert(
std::pair<ir::StrAttribute, ir::Attribute>(attr_name, attr_value));
return ir::DictionaryAttribute::get(ctx, named_attr);
}
TEST(value_test, value_test) {
ir::IrContext *ctx = ir::IrContext::Instance();
// 1. Construct OP1: a = OP1()
std::vector<ir::OpResult> op1_inputs = {};
std::vector<ir::Type> op1_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op1 = ir::Operation::create(
op1_inputs, op1_output_types, CreateAttribute("op1_name", "op1_attr"));
std::cout << op1->print() << std::endl;
// 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op2 = ir::Operation::create(
op2_inputs, op2_output_types, CreateAttribute("op2_name", "op2_attr"));
std::cout << op2->print() << std::endl;
// 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
op2->GetResultByIndex(0)};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation::create(
op3_inputs, op3_output_types, CreateAttribute("op3_name", "op3_attr"));
std::cout << op3->print() << std::endl;
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
op3->GetResultByIndex(0)};
std::vector<ir::Type> op4_output_types;
for (size_t i = 0; i < 7; i++) {
op4_output_types.push_back(ir::Float32Type::get(ctx));
}
ir::Operation *op4 = ir::Operation::create(
op4_inputs, op4_output_types, CreateAttribute("op4_name", "op4_attr"));
std::cout << op4->print() << std::endl;
// Test 1:
EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1);
EXPECT_EQ(op2->GetResultByIndex(0).GetDefiningOp(), op2);
EXPECT_EQ(op3->GetResultByIndex(0).GetDefiningOp(), op3);
EXPECT_EQ(op4->GetResultByIndex(6).GetDefiningOp(), op4);
// Test 2: op1_first_output -> op4_first_input
ir::OpResult op1_first_output = op1->GetResultByIndex(0);
ir::detail::OpOperandImpl *op4_first_input =
reinterpret_cast<ir::detail::OpOperandImpl *>(
reinterpret_cast<uintptr_t>(op4) + sizeof(ir::Operation));
EXPECT_EQ(static_cast<ir::Value>(op1_first_output).impl()->first_use(),
op4_first_input);
ir::detail::OpOperandImpl *op3_first_input =
reinterpret_cast<ir::detail::OpOperandImpl *>(
reinterpret_cast<uintptr_t>(op3) + sizeof(ir::Operation));
EXPECT_EQ(op4_first_input->next_use(), op3_first_input);
EXPECT_EQ(op3_first_input->next_use(), nullptr);
// destroy
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op4->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op3->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op2->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op1->destroy();
}
......@@ -21,6 +21,7 @@
#include "paddle/ir/ir_context.h"
#include "paddle/ir/type.h"
#include "paddle/ir/type_base.h"
#include "paddle/ir/utils.h"
TEST(type_test, type_id) {
// Define two empty classes, just for testing.
......@@ -172,8 +173,8 @@ struct IntegerTypeStorage : public ir::TypeStorage {
using ParamKey = std::pair<unsigned, unsigned>;
static std::size_t HashValue(const ParamKey &key) {
return hash_combine(std::hash<unsigned>()(std::get<0>(key)),
std::hash<unsigned>()(std::get<1>(key)));
return ir::hash_combine(std::hash<unsigned>()(std::get<0>(key)),
std::hash<unsigned>()(std::get<1>(key)));
}
bool operator==(const ParamKey &key) const {
......@@ -188,11 +189,6 @@ struct IntegerTypeStorage : public ir::TypeStorage {
unsigned width_ : 30;
unsigned signedness_ : 2;
private:
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
};
// Customize a parameterized type: IntegerType, storage type is
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/utils.h"
namespace ir {
std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
}
void *aligned_malloc(size_t size, size_t alignment) {
assert(alignment >= sizeof(void *) && (alignment & (alignment - 1)) == 0);
size = (size + alignment - 1) / alignment * alignment;
#if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200112L
void *aligned_mem = nullptr;
if (posix_memalign(&aligned_mem, alignment, size) != 0) {
aligned_mem = nullptr;
}
return aligned_mem;
#elif defined(_WIN32)
return _aligned_malloc(size, alignment);
#else
void *mem = malloc(size + alignment);
if (mem == nullptr) {
return nullptr;
}
size_t adjust = alignment - reinterpret_cast<uint64_t>(mem) % alignment;
void *aligned_mem = reinterpret_cast<char *>(mem) + adjust;
*(reinterpret_cast<void **>(aligned_mem) - 1) = mem;
assert(reinterpret_cast<uint64_t>(aligned_mem) % alignment == 0);
return aligned_mem;
#endif
}
void aligned_free(void *mem_ptr) {
#if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200112L
free(mem_ptr);
#elif defined(_WIN32)
_aligned_free(mem_ptr);
#else
if (mem_ptr) {
free(*(reinterpret_cast<void **>(mem_ptr) - 1));
}
#endif
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstdint>
#include <cstdlib>
namespace ir {
std::size_t hash_combine(std::size_t lhs, std::size_t rhs);
void *aligned_malloc(size_t size, size_t alignment);
void aligned_free(void *mem_ptr);
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/value.h"
#include "paddle/ir/value_impl.h"
namespace ir {
// Operand
OpOperand::OpOperand(const detail::OpOperandImpl *impl)
: impl_(const_cast<detail::OpOperandImpl *>(impl)) {}
OpOperand &OpOperand::operator=(const OpOperand &rhs) {
if (this == &rhs) return *this;
impl_ = rhs.impl_;
return *this;
}
OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
if (this->impl_ == impl) return *this;
impl_ = const_cast<detail::OpOperandImpl *>(impl);
return *this;
}
bool OpOperand::operator==(OpOperand other) const {
return impl_ == other.impl_;
}
bool OpOperand::operator!=(OpOperand other) const {
return impl_ != other.impl_;
}
bool OpOperand::operator!() const { return impl_ == nullptr; }
OpOperand::operator bool() const { return impl_; }
detail::OpOperandImpl *OpOperand::impl() const { return impl_; }
// Value
Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {}
bool Value::operator==(const Value &other) const {
return impl_ == other.impl_;
}
bool Value::operator!=(const Value &other) const {
return impl_ != other.impl_;
}
bool Value::operator!() const { return impl_ == nullptr; }
Value::operator bool() const { return impl_; }
detail::ValueImpl *Value::impl() const { return impl_; }
ir::Type Value::type() const { return impl_->type(); }
void Value::SetType(ir::Type type) { impl_->SetType(type); }
Operation *Value::GetDefiningOp() const {
if (auto result = dyn_cast<OpResult>()) return result.owner();
return nullptr;
}
std::string Value::print_ud_chain() { return impl_->print_ud_chain(); }
// OpResult
bool OpResult::classof(Value value) {
return ir::isa<detail::OpResultImpl>(value.impl());
}
Operation *OpResult::owner() const { return impl()->owner(); }
uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); }
detail::OpResultImpl *OpResult::impl() const {
return reinterpret_cast<detail::OpResultImpl *>(impl_);
}
uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
uint32_t max_inline_index =
ir::detail::OpResultImpl::GetMaxInlineResultIndex();
return index <= max_inline_index ? index : max_inline_index;
}
// details
namespace detail {
ir::Operation *OpOperandImpl::owner() const { return owner_; }
ir::detail::OpOperandImpl *OpOperandImpl::next_use() { return next_use_; }
OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
: source_(source), owner_(owner) {
prev_use_addr_ = source.impl()->first_use_addr();
next_use_ = source.impl()->first_use();
if (next_use_) {
next_use_->prev_use_addr_ = &next_use_;
}
source.impl()->SetFirstUse(this);
}
void OpOperandImpl::remove_from_ud_chain() {
if (!prev_use_addr_) return;
if (prev_use_addr_ == source_.impl()->first_use_addr()) {
/// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits
/// storage index information, so need to be updated using the SetFirstUse
/// method here.
source_.impl()->SetFirstUse(next_use_);
} else {
*prev_use_addr_ = next_use_;
}
if (next_use_) {
next_use_->prev_use_addr_ = prev_use_addr_;
}
}
OpOperandImpl::~OpOperandImpl() { remove_from_ud_chain(); }
uint32_t ValueImpl::index() const {
uint32_t index =
reinterpret_cast<uintptr_t>(first_use_offseted_by_index_) & 0x07;
if (index < 6) return index;
return reinterpret_cast<OpOutlineResultImpl *>(const_cast<ValueImpl *>(this))
->GetResultIndex();
}
std::string ValueImpl::print_ud_chain() {
std::stringstream result;
result << "Value[" << this << "] -> ";
OpOperandImpl *tmp = first_use();
if (tmp) {
result << "OpOperand[" << reinterpret_cast<void *>(tmp) << "] -> ";
while (tmp->next_use() != nullptr) {
result << "OpOperand[" << reinterpret_cast<void *>(tmp->next_use())
<< "] -> ";
tmp = tmp->next_use();
}
}
result << "nullptr";
return result.str();
}
uint32_t OpResultImpl::GetResultIndex() const {
if (const auto *outline_result = ir::dyn_cast<OpOutlineResultImpl>(this)) {
return outline_result->GetResultIndex();
}
return ir::dyn_cast<OpInlineResultImpl>(this)->GetResultIndex();
}
ir::Operation *OpResultImpl::owner() const {
// For inline result, pointer offset index to obtain the address of op.
if (const auto *result = ir::dyn_cast<OpInlineResultImpl>(this)) {
result += result->GetResultIndex() + 1;
return reinterpret_cast<Operation *>(
const_cast<OpInlineResultImpl *>(result));
}
// For outline result, pointer offset outline_index to obtain the address of
// maximum inline result.
const OpOutlineResultImpl *outline_result =
(const OpOutlineResultImpl *)(this);
outline_result +=
(outline_result->outline_index_ - GetMaxInlineResultIndex());
// The offset of the maximum inline result distance op is
// GetMaxInlineResultIndex.
const auto *inline_result =
reinterpret_cast<const OpInlineResultImpl *>(outline_result);
inline_result += (GetMaxInlineResultIndex() + 1);
return reinterpret_cast<Operation *>(
const_cast<OpInlineResultImpl *>(inline_result));
}
} // namespace detail
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/cast_utils.h"
#include "paddle/ir/type.h"
namespace ir {
class Operation;
namespace detail {
class OpOperandImpl;
class ValueImpl;
class OpResultImpl;
} // namespace detail
///
/// \brief OpOperand class represents the operand of operation. This class only
/// provides interfaces, for specific implementation, see Impl class.
///
class OpOperand {
public:
OpOperand() = default;
OpOperand(const OpOperand &other) = default;
OpOperand(const detail::OpOperandImpl *impl); // NOLINT
OpOperand &operator=(const OpOperand &rhs);
OpOperand &operator=(const detail::OpOperandImpl *impl);
bool operator==(OpOperand other) const;
bool operator!=(OpOperand other) const;
bool operator!() const;
explicit operator bool() const;
detail::OpOperandImpl *impl() const;
private:
detail::OpOperandImpl *impl_{nullptr};
};
///
/// \brief Value class represents the SSA value in the IR system. This class
/// only provides interfaces, for specific implementation, see Impl class.
///
class Value {
public:
Value() = default;
Value(const detail::ValueImpl *impl); // NOLINT
Value(const Value &other) = default;
bool operator==(const Value &other) const;
bool operator!=(const Value &other) const;
bool operator!() const;
explicit operator bool() const;
template <typename T>
bool isa() const {
return ir::isa<T>(*this);
}
template <typename U>
U dyn_cast() const {
return ir::dyn_cast<U>(*this);
}
detail::ValueImpl *impl() const;
ir::Type type() const;
void SetType(ir::Type type);
Operation *GetDefiningOp() const;
std::string print_ud_chain();
friend struct std::hash<Value>;
protected:
detail::ValueImpl *impl_{nullptr};
};
///
/// \brief OpResult class represents the value defined by a result of operation.
/// This class only provides interfaces, for specific implementation, see Impl
/// class.
///
class OpResult : public Value {
public:
using Value::Value;
static bool classof(Value value);
Operation *owner() const;
uint32_t GetResultIndex() const;
friend Operation;
private:
static uint32_t GetValidInlineIndex(uint32_t index);
detail::OpResultImpl *impl() const;
};
} // namespace ir
namespace std {
template <>
struct hash<ir::Value> {
std::size_t operator()(const ir::Value &obj) const {
return std::hash<const ir::detail::ValueImpl *>()(obj.impl_);
}
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/value.h"
namespace ir {
static const uint32_t OUTLINE_OP_RESULT_INDEX = 6;
class Operation;
namespace detail {
///
/// \brief OpOperandImpl
///
class OpOperandImpl {
public:
ir::Operation *owner() const;
ir::detail::OpOperandImpl *next_use();
/// Remove this operand from the current use list.
void remove_from_ud_chain();
~OpOperandImpl();
friend ir::Operation;
private:
OpOperandImpl(ir::Value source, ir::Operation *owner);
ir::detail::OpOperandImpl *next_use_ = nullptr;
ir::detail::OpOperandImpl **prev_use_addr_ = nullptr;
ir::Value source_;
ir::Operation *owner_ = nullptr;
};
///
/// \brief ValueImpl is the base class of all drived Value classes such as
/// OpResultImpl. This class defines all the information and usage interface in
/// the IR Value. Each Value include three attributes:
/// (1) type: ir::Type; (2) UD-chain of value: OpOperandImpl*, first operand
/// address with offset of this value; (3) index: the position where the output
/// list of the parent operator.
///
class alignas(8) ValueImpl {
public:
///
/// \brief Interface functions of "type_" attribute.
///
ir::Type type() const { return type_; }
void SetType(ir::Type type) { type_ = type; }
///
/// \brief Interface functions of "first_use_offseted_by_index_" attribute.
///
uint32_t index() const;
OpOperandImpl *first_use() const {
return reinterpret_cast<OpOperandImpl *>(
reinterpret_cast<uintptr_t>(first_use_offseted_by_index_) & (~0x07));
}
void SetFirstUse(OpOperandImpl *first_use) {
uint32_t offset = index();
first_use_offseted_by_index_ = reinterpret_cast<OpOperandImpl *>(
reinterpret_cast<uintptr_t>(first_use) + offset);
VLOG(4) << "The index of this value is " << offset
<< ". Offset and set first use: " << first_use << " -> "
<< first_use_offseted_by_index_ << ".";
}
OpOperandImpl **first_use_addr() { return &first_use_offseted_by_index_; }
bool use_empty() const { return first_use() == nullptr; }
std::string print_ud_chain();
protected:
///
/// \brief Only can be constructed by derived classes such as OpResultImpl.
///
explicit ValueImpl(ir::Type type, uint32_t index) {
if (index > OUTLINE_OP_RESULT_INDEX) {
throw("The value of index must not exceed 6");
}
type_ = type;
first_use_offseted_by_index_ = reinterpret_cast<OpOperandImpl *>(
reinterpret_cast<uintptr_t>(nullptr) + index);
VLOG(4) << "Construct a ValueImpl whose's index is " << index
<< ". The offset first_use address is: "
<< first_use_offseted_by_index_;
}
///
/// \brief Attribute1: Type of value.
///
ir::Type type_;
///
/// \brief Attribute2/3: Record the UD-chain of value and index.
/// NOTE: The members of the OpOperandImpl include four pointers, so this
/// class is 8-byte aligned, and the lower 3 bits of its address are 0, so the
/// index can be stored in these 3 bits, stipulate:
/// (1) index = 0~5: represent positions 0 to 5 inline
/// output(OpInlineResultImpl); (2) index = 6: represent the position >=6
/// outline output(OpOutlineResultImpl); (3) index = 7 is reserved.
///
OpOperandImpl *first_use_offseted_by_index_ = nullptr;
};
///
/// \brief OpResultImpl is the implementation of an operation result.
///
class alignas(8) OpResultImpl : public ValueImpl {
public:
using ValueImpl::ValueImpl;
static bool classof(const ValueImpl &value) { return true; }
///
/// \brief Get the parent operation of this result.(op_ptr = value_ptr +
/// index)
///
ir::Operation *owner() const;
///
/// \brief Get the result index of the operation result.
///
uint32_t GetResultIndex() const;
///
/// \brief Get the maximum number of results that can be stored inline.
///
static uint32_t GetMaxInlineResultIndex() {
return OUTLINE_OP_RESULT_INDEX - 1;
}
};
///
/// \brief OpInlineResultImpl is the implementation of an operation result whose
/// index <= 5.
///
class OpInlineResultImpl : public OpResultImpl {
public:
OpInlineResultImpl(ir::Type type, uint32_t result_index)
: OpResultImpl(type, result_index) {
if (result_index > GetMaxInlineResultIndex()) {
throw("Inline result index should not exceed MaxInlineResultIndex(5)");
}
}
static bool classof(const OpResultImpl &value) {
return value.index() < OUTLINE_OP_RESULT_INDEX;
}
uint32_t GetResultIndex() const { return index(); }
};
///
/// \brief OpOutlineResultImpl is the implementation of an operation result
/// whose index > 5.
///
class OpOutlineResultImpl : public OpResultImpl {
public:
OpOutlineResultImpl(ir::Type type, uint32_t outline_index)
: OpResultImpl(type, OUTLINE_OP_RESULT_INDEX),
outline_index_(outline_index) {}
static bool classof(const OpResultImpl &value) {
return value.index() >= OUTLINE_OP_RESULT_INDEX;
}
uint32_t GetResultIndex() const { return outline_index_; }
uint32_t outline_index_;
};
} // namespace detail
} // namespace ir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册