diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
index ce0f393200c87c85245a340e15e474725a82e78c..31ba23b0e1bbc034776e96705e1aa8a35197b71c 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
@@ -16,18 +16,6 @@
 namespace paddle {
 namespace dialect {
-const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
-
-const phi::DDim& DenseTensorType::dims() const { return storage()->dims_; }
-
-const phi::DataLayout& DenseTensorType::data_layout() const {
-  return storage()->layout_;
-}
-
-const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; }
-
-const size_t& DenseTensorType::offset() const { return storage()->offset_; }
-
 const ir::Type& SelectedRowsType::dtype() const { return storage()->dtype_; }
 
 const phi::DDim& SelectedRowsType::dims() const { return storage()->dims_; }
@@ -43,5 +31,4 @@ const size_t& SelectedRowsType::offset() const { return storage()->offset_; }
 }  // namespace dialect
 }  // namespace paddle
 
-IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h
index 640e4ab2392385f956acd19ebda963adcd29600c..9525e1a88b346ed71ec565b5f0a9e40a6d1756b0 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h
@@ -15,30 +15,12 @@
 #pragma once
 
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
+#include "paddle/ir/core/builtin_type.h"
 #include "paddle/ir/core/type.h"
 
 namespace paddle {
 namespace dialect {
-///
-/// \brief Define built-in parametric types.
-///
-class DenseTensorType : public ir::Type {
- public:
-  using Type::Type;
-
-  DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
-
-  const ir::Type &dtype() const;
-
-  const phi::DDim &dims() const;
-
-  const phi::DataLayout &data_layout() const;
-
-  const phi::LoD &lod() const;
-
-  const size_t &offset() const;
-};
-
+using DenseTensorType = ir::DenseTensorType;
 class SelectedRowsType : public ir::Type {
  public:
   using Type::Type;
@@ -59,5 +41,4 @@
 }  // namespace dialect
 }  // namespace paddle
 
-IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType)
 IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
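Note (editorial, not part of the patch): after this change paddle::dialect::DenseTensorType is only an alias for ir::DenseTensorType, so existing fluid call sites keep compiling while the implementation moves into ir (the matching storage alias follows in pd_type_storage.h below). A minimal construction sketch, assuming the usual `get` overload that DECLARE_TYPE_UTILITY_FUNCTOR generates from the storage's ParamKey; function and variable names here are illustrative, not from the patch:

    #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"

    void DenseTensorTypeSketch(ir::IrContext* ctx) {
      ir::Type fp32 = ir::Float32Type::get(ctx);
      // ParamKey order: dtype, dims, layout, lod, offset.
      paddle::dialect::DenseTensorType t =
          paddle::dialect::DenseTensorType::get(ctx,
                                                fp32,
                                                phi::make_ddim({2, 3}),
                                                phi::DataLayout::NCHW,
                                                phi::LoD(),
                                                /*offset=*/0);
      ir::DenseTensorType same = t;  // same TypeId, no cast needed
      (void)same;
    }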
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h
index fcfff1db5ae855d76ec69ea99c69c447a8d99fbb..1a74b6d6c105926a4a0e3f9b5ed9b62d6df286e1 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h
@@ -16,117 +16,15 @@
 
 #include <type_traits>
 
+#include "paddle/ir/core/builtin_type_storage.h"
 #include "paddle/ir/core/type.h"
 #include "paddle/ir/core/type_base.h"
 #include "paddle/ir/core/utils.h"
 #include "paddle/phi/core/tensor_meta.h"
 
-namespace std {
-///
-/// \brief Enable hashing std::vector<T> instances.
-///
-template <typename T>
-struct hash<std::vector<T>> {
-  std::size_t operator()(const std::vector<T>& dim) const {
-    std::size_t seed = 0;
-    for (size_t i = 0; i < dim.size(); ++i) {
-      seed ^= std::hash<T>()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
-    }
-    return seed;
-  }
-};
-
-}  // namespace std
-
 namespace paddle {
 namespace dialect {
-///
-/// \brief Define Parametric TypeStorage for DenseTensorType.
-///
-/// NOTE(zhangbo9674): The derived TypeStorage class needs to implement the
-/// following methods: (1)declare ParamKey, (2)define Construction method,
-/// (3)define HashValue method, (4)overload operator==.
-///
-struct DenseTensorTypeStorage : public ir::TypeStorage {
-  using DataLayout = phi::DataLayout;
-  using Dim = phi::DDim;
-  using LoD = std::vector<std::vector<size_t>>;
-  ///
-  /// \brief Declare ParamKey according to parameter type.
-  ///
-  using ParamKey =
-      std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>;
-
-  DenseTensorTypeStorage(const ir::Type& dtype,
-                         const phi::DDim& dims,
-                         const phi::DataLayout& layout,
-                         const phi::LoD& lod,
-                         size_t offset)
-      : dtype_(dtype),
-        dims_(dims),
-        layout_(layout),
-        lod_(lod),
-        offset_(offset) {}
-
-  ///
-  /// \brief Each derived TypeStorage must define a Construct method, which
-  /// StorageManager uses to construct a derived TypeStorage.
-  ///
-  static DenseTensorTypeStorage* Construct(const ParamKey& key) {
-    return new DenseTensorTypeStorage(std::get<0>(key),
-                                      std::get<1>(key),
-                                      std::get<2>(key),
-                                      std::get<3>(key),
-                                      std::get<4>(key));
-  }
-
-  ///
-  /// \brief Each derived TypeStorage must provide a HashValue method.
-  ///
-  static std::size_t HashValue(const ParamKey& key) {
-    std::size_t hash_value = 0;
-    // hash dtype
-    hash_value =
-        ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
-    // hash dims
-    hash_value =
-        ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key)));
-    // hash layout
-    hash_value = ir::hash_combine(
-        hash_value,
-        std::hash<std::underlying_type<DataLayout>::type>()(
-            static_cast<std::underlying_type<DataLayout>::type>(
-                std::get<2>(key))));
-    // hash lod
-    hash_value =
-        ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key)));
-    // hash offset
-    hash_value =
-        ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
-    return hash_value;
-  }
-
-  ///
-  /// \brief Each derived TypeStorage needs to overload operator==.
-  ///
-  bool operator==(const ParamKey& key) const {
-    return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
-  }
-
-  ParamKey GetAsKey() const {
-    return ParamKey(dtype_, dims_, layout_, lod_, offset_);
-  }
-
-  ///
-  /// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
-  /// layout, lod, offset.
-  ///
-  ir::Type dtype_;
-  phi::DDim dims_;
-  phi::DataLayout layout_;
-  phi::LoD lod_;
-  size_t offset_;
-};
+using DenseTensorTypeStorage = ir::DenseTensorTypeStorage;
 
 struct SelectedRowsTypeStorage : public ir::TypeStorage {
   using DataLayout = phi::DataLayout;
diff --git a/paddle/ir/core/CMakeLists.txt b/paddle/ir/core/CMakeLists.txt
index c35bc02b344fadc70a533a037930edd5ba758ed0..39a3d71b2712894877bfb51439133cce6c7678c8 100644
--- a/paddle/ir/core/CMakeLists.txt
+++ b/paddle/ir/core/CMakeLists.txt
@@ -3,4 +3,4 @@ set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir")
 
 file(GLOB IR_SRCS "*.cc")
 
-ir_library(ir_core SRCS ${IR_SRCS})
+ir_library(ir_core SRCS ${IR_SRCS} DEPS ddim)
diff --git a/paddle/ir/core/builtin_type.cc b/paddle/ir/core/builtin_type.cc
index 8a0aea5745a5b2a3535fac5bbdd80d8a07adbd07..49a15484466b259fdfc9c0d795b1e7c4b3b37032 100644
--- a/paddle/ir/core/builtin_type.cc
+++ b/paddle/ir/core/builtin_type.cc
@@ -17,6 +17,21 @@
 namespace ir {
 std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }
 
+const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
+
+const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const {
+  return storage()->dims_;
+}
+
+const DenseTensorTypeStorage::DataLayout& DenseTensorType::data_layout() const {
+  return storage()->layout_;
+}
+
+const DenseTensorTypeStorage::LoD& DenseTensorType::lod() const {
+  return storage()->lod_;
+}
+
+const size_t& DenseTensorType::offset() const { return storage()->offset_; }
 }  // namespace ir
 
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::UInt8Type)
@@ -33,3 +48,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type)
+IR_DEFINE_EXPLICIT_TYPE_ID(ir::DenseTensorType)
diff --git a/paddle/ir/core/builtin_type.h b/paddle/ir/core/builtin_type.h
index 9a2939110deaca170db1cb6618e3d2e2ab9327ab..a660f065376b2e90ee615da1436c95de88233bec 100644
--- a/paddle/ir/core/builtin_type.h
+++ b/paddle/ir/core/builtin_type.h
@@ -1,3 +1,4 @@
+
 // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -53,6 +54,23 @@
   Type operator[](size_t index) const { return data()[index]; }
 };
 
+class DenseTensorType : public ir::Type {
+ public:
+  using Type::Type;
+
+  DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
+
+  const ir::Type &dtype() const;
+
+  const DenseTensorTypeStorage::Dim &dims() const;
+
+  const DenseTensorTypeStorage::DataLayout &data_layout() const;
+
+  const DenseTensorTypeStorage::LoD &lod() const;
+
+  const size_t &offset() const;
+};
+
 #define DECLARE_BUILTIN_TYPE(__name) \
   class IR_API __name : public Type { \
    public: \
@@ -99,3 +117,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::DenseTensorType)
diff --git a/paddle/ir/core/builtin_type_storage.h b/paddle/ir/core/builtin_type_storage.h
index 64d0c4284aa7f1a602307ca098379da678ca270a..4488b28b07fa22b76e1d501e56dc2aa56efbd497 100644
--- a/paddle/ir/core/builtin_type_storage.h
+++ b/paddle/ir/core/builtin_type_storage.h
@@ -17,15 +17,122 @@
 
 #include "paddle/ir/core/type.h"
 #include "paddle/ir/core/type_base.h"
 #include "paddle/ir/core/utils.h"
+#include "paddle/phi/common/layout.h"
+#include "paddle/phi/core/ddim.h"
+
+namespace std {
+///
+/// \brief Enable hashing std::vector<T> instances.
+///
+template <typename T>
+struct hash<std::vector<T>> {
+  std::size_t operator()(const std::vector<T>& dim) const {
+    std::size_t seed = 0;
+    for (size_t i = 0; i < dim.size(); ++i) {
+      seed ^= std::hash<T>()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    }
+    return seed;
+  }
+};
+
+}  // namespace std
 
 namespace ir {
+///
+/// \brief Define Parametric TypeStorage for DenseTensorType.
+///
+/// NOTE(zhangbo9674): The derived TypeStorage class needs to implement the
+/// following methods: (1)declare ParamKey, (2)define Construction method,
+/// (3)define HashValue method, (4)overload operator==.
+///
+
+struct DenseTensorTypeStorage : public ir::TypeStorage {
+  ///
+  /// \brief Declare ParamKey according to parameter type.
+  ///
+  using DataLayout = phi::DataLayout;
+  using Dim = phi::DDim;
+  using LoD = std::vector<std::vector<size_t>>;
+  using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
+
+  DenseTensorTypeStorage(const ir::Type& dtype,
+                         const Dim& dims,
+                         const DataLayout& layout,
+                         const LoD& lod,
+                         size_t offset)
+      : dtype_(dtype),
+        dims_(dims),
+        layout_(layout),
+        lod_(lod),
+        offset_(offset) {}
+
+  ///
+  /// \brief Each derived TypeStorage must define a Construct method, which
+  /// StorageManager uses to construct a derived TypeStorage.
+  ///
+  static DenseTensorTypeStorage* Construct(const ParamKey& key) {
+    return new DenseTensorTypeStorage(std::get<0>(key),
+                                      std::get<1>(key),
+                                      std::get<2>(key),
+                                      std::get<3>(key),
+                                      std::get<4>(key));
+  }
+
+  ///
+  /// \brief Each derived TypeStorage must provide a HashValue method.
+  ///
+  static std::size_t HashValue(const ParamKey& key) {
+    std::size_t hash_value = 0;
+    // hash dtype
+    hash_value =
+        ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
+    // hash dims
+    hash_value =
+        ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
+    // hash layout
+    hash_value = ir::hash_combine(
+        hash_value,
+        std::hash<std::underlying_type<DataLayout>::type>()(
+            static_cast<std::underlying_type<DataLayout>::type>(
+                std::get<2>(key))));
+    // hash lod
+    hash_value =
+        ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
+    // hash offset
+    hash_value =
+        ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
+    return hash_value;
+  }
+
+  ///
+  /// \brief Each derived TypeStorage needs to overload operator==.
+  ///
+  bool operator==(const ParamKey& key) const {
+    return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
+  }
+
+  ParamKey GetAsKey() const {
+    return ParamKey(dtype_, dims_, layout_, lod_, offset_);
+  }
+
+  ///
+  /// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
+  /// layout, lod, offset.
+  ///
+  ir::Type dtype_;
+  Dim dims_;
+  DataLayout layout_;
+  LoD lod_;
+  size_t offset_;
+};
+
 struct VectorTypeStorage : public TypeStorage {
   using ParamKey = std::vector<Type>;
 
-  explicit VectorTypeStorage(const ParamKey &key) {
-    data_ = reinterpret_cast<Type *>(malloc(key.size() * sizeof(Type)));
-    memcpy(reinterpret_cast<void *>(data_),
-           reinterpret_cast<const void *>(key.data()),
+  explicit VectorTypeStorage(const ParamKey& key) {
+    data_ = reinterpret_cast<Type*>(malloc(key.size() * sizeof(Type)));
+    memcpy(reinterpret_cast<void*>(data_),
+           reinterpret_cast<const void*>(key.data()),
            key.size() * sizeof(Type));
     size_ = key.size();
   }
 
@@ -36,14 +143,14 @@
   /// \brief Each derived TypeStorage must define a Construc method, which
   /// StorageManager uses to construct a derived TypeStorage.
   ///
-  static VectorTypeStorage *Construct(const ParamKey &key) {
+  static VectorTypeStorage* Construct(const ParamKey& key) {
     return new VectorTypeStorage(key);
  }
 
   ///
   /// \brief Each derived TypeStorage must provide a HashValue method.
   ///
-  static std::size_t HashValue(const ParamKey &key) {
+  static std::size_t HashValue(const ParamKey& key) {
     std::size_t hash_value = 0;
     for (size_t i = 0; i < key.size(); ++i) {
       hash_value = hash_combine(hash_value, std::hash<Type>()(key[i]));
     }
     return hash_value;
   }
 
@@ -54,7 +161,7 @@
   ///
   /// \brief Each derived TypeStorage needs to overload operator==.
   ///
-  bool operator==(const ParamKey &key) const {
+  bool operator==(const ParamKey& key) const {
     if (key.size() != size_) {
       return false;
     }
@@ -72,7 +179,7 @@
   /// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
   /// layout, lod, offset.
   ///
-  Type *data_;
+  Type* data_;
   size_t size_;
 };
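Note (editorial, not part of the patch): the std::hash<std::vector<T>> specialization travels with the storage so that HashValue can hash the LoD parameter, a std::vector<std::vector<size_t>>; the outer vector is hashed by the specialization instantiated with T = std::vector<size_t>, which in turn relies on it for T = size_t. A standalone, runnable sketch of the same boost-style combine:

    #include <cstddef>
    #include <functional>
    #include <vector>

    // Each element perturbs the running seed, so both the values and their
    // order affect the result; 0x9e3779b9 is the golden-ratio constant used
    // by boost::hash_combine.
    std::size_t HashVector(const std::vector<std::size_t>& dim) {
      std::size_t seed = 0;
      for (std::size_t v : dim) {
        seed ^= std::hash<std::size_t>()(v) + 0x9e3779b9 + (seed << 6) +
                (seed >> 2);
      }
      return seed;
    }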
diff --git a/paddle/ir/dialect/shape/CMakeLists.txt b/paddle/ir/dialect/shape/CMakeLists.txt
index e0356f14345f474d879d137b421193bffa167e85..62d7c0d42c85c8a78814851fd838dd8179259e61 100644
--- a/paddle/ir/dialect/shape/CMakeLists.txt
+++ b/paddle/ir/dialect/shape/CMakeLists.txt
@@ -1,9 +1,2 @@
 file(GLOB_RECURSE SHAPE_SRCS "*.cc")
-ir_library(
-  ir_shape
-  SRCS
-  ${SHAPE_SRCS}
-  ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
-  DEPS
-  ddim
-  ir_core)
+ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core)
diff --git a/paddle/ir/dialect/shape/ir/shape_dialect.cc b/paddle/ir/dialect/shape/ir/shape_dialect.cc
index c2fd60d88a4783c14bd02909f5effe0ee449b340..d058924511bcd59b76d32b86a09b51deec19c09d 100644
--- a/paddle/ir/dialect/shape/ir/shape_dialect.cc
+++ b/paddle/ir/dialect/shape/ir/shape_dialect.cc
@@ -22,7 +22,9 @@ ShapeDialect::ShapeDialect(IrContext *context)
   initialize();
 }
 
-void ShapeDialect::initialize() { RegisterOps<SymbolicDim>(); }
+void ShapeDialect::initialize() {
+  RegisterOps<SymbolicDim, DimOp, TieProductEqualOp>();
+}
 
 }  // namespace dialect
 }  // namespace ir
diff --git a/paddle/ir/dialect/shape/ir/shape_op.cc b/paddle/ir/dialect/shape/ir/shape_op.cc
index 4d418403d60a34ff02ab41c82dab6bca5396632a..3681aafa36520bea1665d3b1d068853f2a44fa8e 100644
--- a/paddle/ir/dialect/shape/ir/shape_op.cc
+++ b/paddle/ir/dialect/shape/ir/shape_op.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/ir/dialect/shape/ir/shape_op.h"
 #include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/ir/core/builtin_type.h"
 
 namespace ir {
 namespace dialect {
@@ -132,7 +133,65 @@ bool SymbolicDim::merge(SymbolicDim other) {
   return true;
 }
 
+const char *DimOp::attributes_name[attributes_num] = {"name"};  // NOLINT
+
+void DimOp::Build(Builder &builder,
+                  OperationArgument &argument,
+                  const std::string &name) {
+  ir::Attribute attr_name =
+      ir::StrAttribute::get(ir::IrContext::Instance(), name);
+  argument.AddAttribute("name", attr_name);
+  argument.output_types.emplace_back(
+      ir::IndexType::get(ir::IrContext::Instance()));
+}
+
+const std::string DimOp::getName() {
+  return attribute<ir::StrAttribute>("name").AsString();
+}
+
+void DimOp::setName(std::string attrName) {
+  operation()->set_attribute(
+      "name", ir::StrAttribute::get(ir::IrContext::Instance(), attrName));
+}
+
+const char *TieProductEqualOp::attributes_name[attributes_num] = {
+    "lhs_len", "rhs_len"};  // NOLINT
+
+void TieProductEqualOp::Build(Builder &builder,
+                              OperationArgument &argument,
+                              int64_t lhs_len,
+                              int64_t rhs_len,
+                              const std::vector<ir::OpResult> &inputs) {
+  ir::Attribute attr_lhs_len =
+      ir::Int64Attribute::get(ir::IrContext::Instance(), lhs_len);
+  argument.AddAttribute("lhs_len", attr_lhs_len);
+  ir::Attribute attr_rhs_len =
+      ir::Int64Attribute::get(ir::IrContext::Instance(), rhs_len);
+  argument.AddAttribute("rhs_len", attr_rhs_len);
+  argument.inputs = inputs;
+}
+
+std::vector<ir::OpResult> TieProductEqualOp::getLhs() {
+  int64_t lhs_len = attribute<ir::Int64Attribute>("lhs_len").data();
+  std::vector<ir::OpResult> res;
+  for (uint32_t idx = 0; idx < lhs_len; idx++) {
+    res.push_back(operand_source(idx));
+  }
+  return res;
+}
+std::vector<ir::OpResult> TieProductEqualOp::getRhs() {
+  int64_t lhs_len = attribute<ir::Int64Attribute>("lhs_len").data();
+  int64_t rhs_len = attribute<ir::Int64Attribute>("rhs_len").data();
+  std::vector<ir::OpResult> res;
+  for (uint32_t idx = 0; idx < rhs_len; idx++) {
+    res.push_back(operand_source(lhs_len + idx));
+  }
+  return res;
+}
+
 }  // namespace dialect
 }  // namespace ir
 
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim)
+IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::DimOp)
+IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::TieProductEqualOp)
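Note (editorial, not part of the patch): TieProductEqualOp keeps no structured operand groups; Build concatenates the lhs and rhs values into one flat input list, and the lhs_len/rhs_len attributes record the split point, which is exactly what getLhs/getRhs use to slice the operands back apart. A hedged fragment mirroring the test at the end of this patch (the builder and the five values are assumed to exist; names are illustrative):

    // Inputs are laid out flat as [ l0, l1 | r0, r1, r2 ] for
    // lhs_len = 2, rhs_len = 3.
    void TieProductEqualSketch(ir::Builder &builder,
                               const std::vector<ir::OpResult> &five_values) {
      auto op = builder.Build<ir::dialect::TieProductEqualOp>(2, 3, five_values);
      std::vector<ir::OpResult> lhs = op.getLhs();  // operands [0, 2)
      std::vector<ir::OpResult> rhs = op.getRhs();  // operands [2, 5)
      (void)lhs;
      (void)rhs;
    }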
diff --git a/paddle/ir/dialect/shape/ir/shape_op.h b/paddle/ir/dialect/shape/ir/shape_op.h
index d04f0002da536a55b68c44374f712dedf262e1c7..af61393a24c9b3d22028ec256f06afb936bd5b2a 100644
--- a/paddle/ir/dialect/shape/ir/shape_op.h
+++ b/paddle/ir/dialect/shape/ir/shape_op.h
@@ -57,7 +57,45 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
   void Verify() {}
 };
 
+class IR_API DimOp : public Op<DimOp> {
+ public:
+  using Op::Op;
+  static const char *name() { return "shape.dim"; }
+
+  static constexpr uint32_t attributes_num = 1;
+  static const char *attributes_name[attributes_num];
+
+  static void Build(Builder &builder,             // NOLINT
+                    OperationArgument &argument,  // NOLINT
+                    const std::string &name);
+
+  const std::string getName();
+  void setName(std::string attrValue);
+  ir::OpResult out() { return result(0); }
+  void Verify() {}
+};
+
+class IR_API TieProductEqualOp : public Op<TieProductEqualOp> {
+ public:
+  using Op::Op;
+  static const char *name() { return "shape.tie_product_equal"; }
+
+  static constexpr uint32_t attributes_num = 2;
+  static const char *attributes_name[attributes_num];
+  // attr operand_segment_sizes
+  static void Build(Builder &builder,             // NOLINT
+                    OperationArgument &argument,  // NOLINT
+                    int64_t lhs_len,
+                    int64_t rhs_len,
+                    const std::vector<ir::OpResult> &inputs);
+  std::vector<ir::OpResult> getLhs();
+  std::vector<ir::OpResult> getRhs();
+  void Verify() {}
+};
+
 }  // namespace dialect
 }  // namespace ir
 
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim);
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::DimOp);
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::TieProductEqualOp);
diff --git a/paddle/ir/dialect/shape/utils/shape_utils.cc b/paddle/ir/dialect/shape/utils/shape_utils.cc
index d9676ad5a78b3b24fffa9d351c1860a41ffc75d7..182d335f71c3d4cf3f845ff92b53f78d216625c6 100644
--- a/paddle/ir/dialect/shape/utils/shape_utils.cc
+++ b/paddle/ir/dialect/shape/utils/shape_utils.cc
@@ -18,8 +18,8 @@
 namespace ir {
 
 bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) {
-  if (lhs.size() < 1 || lhs[0] != 'S' && lhs[0] != 'C') return lhs < rhs;
-  if (rhs.size() < 1 || rhs[0] != 'S' && rhs[0] != 'C') return lhs < rhs;
+  if (lhs.size() < 1 || (lhs[0] != 'S' && lhs[0] != 'C')) return lhs < rhs;
+  if (rhs.size() < 1 || (rhs[0] != 'S' && rhs[0] != 'C')) return lhs < rhs;
   int64_t lhsIdx = 0, rhsIdx = 0;
   try {
     lhsIdx = stol(lhs.substr(1));
@@ -30,18 +30,19 @@ bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) {
   return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx);
 }
 
-ir::Operation* SymbolTable::lookup(const std::string& name) const {
-  auto it = symbolTableMap_.find(name);
-  return it != symbolTableMap_.end() ? it->second : nullptr;
-}
-
 const std::string SymbolTable::insert(ir::Operation* symbol) {
   std::string name;
-  if (symbol->HasAttribute("sym_name")) {
+  if (symbol->name() == "shape.SymbolicDim") {
     name = symbol->dyn_cast<SymbolicDim>().getSymName();
+    symbolTableMap_.insert({name, symbol});
   }
-  // TODO(liujinnan): add constraint_func name branch.
-  symbolTableMap_.insert({name, symbol});
+
+  // TODO(liujinnan): add more constraint_func name branch.
+  if (symbol->name() == "shape.tie_product_equal") {
+    name = "tie_product_equal";
+    symbolFuncMap_[name].emplace_back(symbol);
+  }
+
   return name;
 }
diff --git a/paddle/ir/dialect/shape/utils/shape_utils.h b/paddle/ir/dialect/shape/utils/shape_utils.h
index 1f9736f7d3e8b54acb5428e14013212745897120..70f2a16c4481e8c30564983249f59210410291e4 100644
--- a/paddle/ir/dialect/shape/utils/shape_utils.h
+++ b/paddle/ir/dialect/shape/utils/shape_utils.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <functional>
+#include <type_traits>
 #include <unordered_map>
 #include "paddle/ir/core/builtin_op.h"
@@ -29,29 +30,51 @@
 struct SymbolicDimProduct {
   std::vector<SymbolicDim> symbols;
   int64_t factor = 1;
   bool empty() { return factor == 1 && symbols.empty(); }
-};
-
-inline bool operator==(const SymbolicDimProduct& lhs,
-                       const SymbolicDimProduct& rhs) {
-  return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols;
-}
+  friend inline bool operator==(const SymbolicDimProduct& lhs,
+                                const SymbolicDimProduct& rhs) {
+    return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols;
+  }
 
-inline bool operator!=(const SymbolicDimProduct& lhs,
-                       const SymbolicDimProduct& rhs) {
-  return !(lhs == rhs);
-}
+  friend inline bool operator!=(const SymbolicDimProduct& lhs,
+                                const SymbolicDimProduct& rhs) {
+    return !(lhs == rhs);
+  }
+};
 
 class SymbolTable {
  public:
   explicit SymbolTable(ir::Operation* symbolTableOp)
       : symbolTableOp_(symbolTableOp) {}
-  ir::Operation* lookup(const std::string& name) const;
+
+  template <typename T>
+  typename std::enable_if<std::is_same<T, SymbolicDim>::value,
+                          SymbolicDim>::type
+  lookup(const std::string& name) const {
+    auto it = symbolTableMap_.find(name);
+    return it != symbolTableMap_.end() ? it->second->dyn_cast<SymbolicDim>()
+                                       : SymbolicDim(nullptr);
+  }
+  template <typename T>
+  typename std::enable_if<!std::is_same<T, SymbolicDim>::value,
+                          std::vector<T>>::type
+  lookup(const std::string& name) const {
+    std::vector<T> res;
+    auto it = symbolFuncMap_.find(name);
+    if (it != symbolFuncMap_.end()) {
+      for (auto& p : it->second) {
+        res.push_back(p->dyn_cast<T>());
+      }
+    }
+    return res;
+  }
+
   const std::string insert(Operation* symbol);
   ir::Operation* getOp() const { return symbolTableOp_; }
 
 private:
   ir::Operation* symbolTableOp_;
   std::unordered_map<std::string, ir::Operation*> symbolTableMap_;
+  std::unordered_map<std::string, std::vector<ir::Operation*>> symbolFuncMap_;
 };
 
 struct SymDimHasher {
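Note (editorial, not part of the patch): lookup is now dispatched purely on the template argument: T = SymbolicDim selects, via enable_if, the single-result overload backed by symbolTableMap_ (a null op when the name is absent, hence the EXPECT_FALSE in the test below), while any other op type selects the vector overload backed by symbolFuncMap_. A minimal sketch:

    void LookupSketch(const ir::SymbolTable &table) {
      // Single op, or SymbolicDim(nullptr) if "S0" is not registered.
      ir::dialect::SymbolicDim s0 =
          table.lookup<ir::dialect::SymbolicDim>("S0");
      // Every op recorded under the name, dyn_cast to the requested type.
      std::vector<ir::dialect::TieProductEqualOp> eqs =
          table.lookup<ir::dialect::TieProductEqualOp>("tie_product_equal");
      (void)s0;
      (void)eqs;
    }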
diff --git a/test/cpp/ir/shape_dialect/symbolic_op_test.cc b/test/cpp/ir/shape_dialect/symbolic_op_test.cc
index 5dfcf19c22f34df240c309660c0ed0c5968e4449..7b0751d17ac138666947f353f6e6dbbea5b6536f 100644
--- a/test/cpp/ir/shape_dialect/symbolic_op_test.cc
+++ b/test/cpp/ir/shape_dialect/symbolic_op_test.cc
@@ -88,10 +88,9 @@ TEST(assist_struct_test, symbolic_dim_table) {
   ir::SymbolTable symbolTable(program.module_op());
 
   EXPECT_EQ(symbolTable.insert(symDim), "S0");
-  EXPECT_EQ(symbolTable.lookup("S0")->dyn_cast<ir::dialect::SymbolicDim>(),
-            symDim);
-  EXPECT_EQ(symbolTable.lookup("S1"), nullptr);
+  EXPECT_EQ(symbolTable.lookup<ir::dialect::SymbolicDim>("S0"), symDim);
   EXPECT_EQ(symbolTable.getOp(), program.module_op());
+  EXPECT_FALSE(symbolTable.lookup<ir::dialect::SymbolicDim>("S1"));
 }
 
 TEST(assist_struct_test, symbolic_dim_mgr) {
@@ -133,15 +132,63 @@ TEST(assist_struct_test, symbolic_dim_mgr) {
   EXPECT_EQ(symDimC10.getValue(), 10);
   EXPECT_EQ(symDimVec[0].getSymName(), "S2");
   EXPECT_EQ(symDimVec[1].getSymName(), "C2");
-  EXPECT_EQ(symDimMgr.symbolTable()
-                .lookup("S0")
-                ->dyn_cast<ir::dialect::SymbolicDim>(),
+  EXPECT_EQ(symDimMgr.symbolTable().lookup<ir::dialect::SymbolicDim>("S0"),
             symDimS0);
-  EXPECT_EQ(symDimMgr.symbolTable()
-                .lookup("C10")
-                ->dyn_cast<ir::dialect::SymbolicDim>(),
+  EXPECT_EQ(symDimMgr.symbolTable().lookup<ir::dialect::SymbolicDim>("C10"),
             symDimC10);
   EXPECT_EQ(symDimMgr.getRootSymbolicDim(symDimS1), symDimS0);
   EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS1));
   EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10));
 }
+
+TEST(assist_struct_test, dim) {
+  ir::IrContext *ctx = ir::IrContext::Instance();
+  ir::Program program(ctx);
+  ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
+  ir::Builder builder = ir::Builder(ctx, program.block());
+
+  ir::dialect::DimOp dimOp = builder.Build<ir::dialect::DimOp>("S0");
+  ir::OpResult res = dimOp.out();
+  EXPECT_EQ(dimOp.getName(), "S0");
+  dimOp.setName("S1");
+  EXPECT_EQ(dimOp.getName(), "S1");
+  EXPECT_EQ(res.GetDefiningOp(), dimOp.operation());
+  EXPECT_EQ(res.type(), ir::IndexType::get(ctx));
+}
+
+TEST(assist_struct_test, tie_product_equal) {
+  ir::IrContext *ctx = ir::IrContext::Instance();
+  ir::Program program(ctx);
+  ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
+  ir::Builder builder = ir::Builder(ctx, program.block());
+  ir::SymbolTable symbolTable(program.module_op());
+
+  ir::OpResult dimOp0 = builder.Build<ir::dialect::DimOp>("S0").out();
+  ir::OpResult dimOp1 = builder.Build<ir::dialect::DimOp>("S1").out();
+  ir::OpResult dimOp2 = builder.Build<ir::dialect::DimOp>("S2").out();
+  ir::OpResult dimOp3 = builder.Build<ir::dialect::DimOp>("S3").out();
+  ir::OpResult dimOp4 = builder.Build<ir::dialect::DimOp>("S4").out();
+
+  ir::dialect::TieProductEqualOp tie_product_equal =
+      builder.Build<ir::dialect::TieProductEqualOp>(
+          2,
+          3,
+          std::vector<ir::OpResult>{dimOp0, dimOp1, dimOp2, dimOp3, dimOp4});
+
+  std::vector<ir::OpResult> lhs = tie_product_equal.getLhs();
+  std::vector<ir::OpResult> rhs = tie_product_equal.getRhs();
+
+  std::vector<ir::OpResult> lhs_ref{dimOp0, dimOp1};
+  std::vector<ir::OpResult> rhs_ref{dimOp2, dimOp3, dimOp4};
+
+  EXPECT_EQ(symbolTable.insert(tie_product_equal), "tie_product_equal");
+  EXPECT_EQ(
+      symbolTable.lookup<ir::dialect::TieProductEqualOp>("tie_product_equal")
+          .size(),
+      static_cast<size_t>(1));
+  EXPECT_EQ(symbolTable.lookup<ir::dialect::TieProductEqualOp>(
+                "tie_product_equal")[0],
+            tie_product_equal);
+  EXPECT_EQ(lhs, lhs_ref);
+  EXPECT_EQ(rhs, rhs_ref);
+}