diff --git a/paddle/ir/dialect/shape/CMakeLists.txt b/paddle/ir/dialect/shape/CMakeLists.txt index dd1b708ce9fe44723c60d9aeabf9be9d64a2351d..e0356f14345f474d879d137b421193bffa167e85 100644 --- a/paddle/ir/dialect/shape/CMakeLists.txt +++ b/paddle/ir/dialect/shape/CMakeLists.txt @@ -1 +1,9 @@ -add_subdirectory(ir) +file(GLOB_RECURSE SHAPE_SRCS "*.cc") +ir_library( + ir_shape + SRCS + ${SHAPE_SRCS} + ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc + DEPS + ddim + ir_core) diff --git a/paddle/ir/dialect/shape/ir/CMakeLists.txt b/paddle/ir/dialect/shape/ir/CMakeLists.txt deleted file mode 100644 index ab8ecdd7eda28cc6b2f2337f391c67b18b74db03..0000000000000000000000000000000000000000 --- a/paddle/ir/dialect/shape/ir/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -file(GLOB SHAPE_SRCS "*.cc") -ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core) diff --git a/paddle/ir/dialect/shape/ir/shape_op.cc b/paddle/ir/dialect/shape/ir/shape_op.cc index c5368987d4fc3f0e3a431130f810e0bfec666611..4d418403d60a34ff02ab41c82dab6bca5396632a 100644 --- a/paddle/ir/dialect/shape/ir/shape_op.cc +++ b/paddle/ir/dialect/shape/ir/shape_op.cc @@ -54,7 +54,7 @@ void SymbolicDim::Build( argument.AddAttribute("knownNonSizeZero", attr_knownNonSizeZero); } -std::string SymbolicDim::getSymName() { +const std::string SymbolicDim::getSymName() { return attribute("sym_name").AsString(); } int64_t SymbolicDim::getValue() { @@ -103,6 +103,35 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) { ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); } +bool SymbolicDim::isDynamic() { + return getValue() == -100000; +} // TODO(zhangbo): getValue() == ShapedType::kDynamic; + +bool SymbolicDim::merge(SymbolicDim other) { + if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue()) + return false; + if (isDynamic() && !other.isDynamic()) updateValue(other.getValue()); + + bool knownNonNegativeFlag = + getKnownNonNegative() || other.getKnownNonNegative(); + bool knownNegativeOneFlag = + getKnownNegativeOne() || other.getKnownNegativeOne(); + bool knownNonSizeOneFlag = getKnownNonSizeOne() || + other.getKnownNonSizeOne() || knownNegativeOneFlag; + bool knownNonSizeZeroFlag = getKnownNonSizeZero() || + other.getKnownNonSizeZero() || + knownNegativeOneFlag; + + if (knownNonNegativeFlag && knownNegativeOneFlag) return false; + + updateKnownNonSizeZero(knownNonSizeZeroFlag); + updateKnownNonSizeOne(knownNonSizeOneFlag); + updateKnownNegativeOne(knownNegativeOneFlag); + updateKnownNonNegative(knownNonNegativeFlag); + + return true; +} + } // namespace dialect } // namespace ir diff --git a/paddle/ir/dialect/shape/ir/shape_op.h b/paddle/ir/dialect/shape/ir/shape_op.h index 48445d4e8cb75f95e7cfc8a5da3bbe8a87cbef18..d04f0002da536a55b68c44374f712dedf262e1c7 100644 --- a/paddle/ir/dialect/shape/ir/shape_op.h +++ b/paddle/ir/dialect/shape/ir/shape_op.h @@ -37,7 +37,7 @@ class IR_API SymbolicDim : public Op { bool knownNegativeOne = false, bool knownNonSizeOne = false, bool knownNonSizeZero = false); - std::string getSymName(); + const std::string getSymName(); int64_t getValue(); bool getKnownNonNegative(); bool getKnownNegativeOne(); @@ -46,11 +46,14 @@ class IR_API SymbolicDim : public Op { void updateSymName(std::string attrValue); void updateValue(int64_t attrValue); - void updateKnownNonNegative(bool attrValue); void updateKnownNegativeOne(bool attrValue); void updateKnownNonSizeOne(bool attrValue); void updateKnownNonSizeZero(bool attrValue); + + bool isDynamic(); + bool merge(SymbolicDim other); + void Verify() {} }; diff --git a/paddle/ir/dialect/shape/utils/shape_utils.cc b/paddle/ir/dialect/shape/utils/shape_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9676ad5a78b3b24fffa9d351c1860a41ffc75d7 --- /dev/null +++ b/paddle/ir/dialect/shape/utils/shape_utils.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect/shape/utils/shape_utils.h" +#include +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" +namespace ir { + +bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { + if (lhs.size() < 1 || lhs[0] != 'S' && lhs[0] != 'C') return lhs < rhs; + if (rhs.size() < 1 || rhs[0] != 'S' && rhs[0] != 'C') return lhs < rhs; + int64_t lhsIdx = 0, rhsIdx = 0; + try { + lhsIdx = stol(lhs.substr(1)); + rhsIdx = stol(rhs.substr(1)); + } catch (const std::exception& e) { + IR_THROW("Invalid symbolic name"); + } + return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx); +} + +ir::Operation* SymbolTable::lookup(const std::string& name) const { + auto it = symbolTableMap_.find(name); + return it != symbolTableMap_.end() ? it->second : nullptr; +} + +const std::string SymbolTable::insert(ir::Operation* symbol) { + std::string name; + if (symbol->HasAttribute("sym_name")) { + name = symbol->dyn_cast().getSymName(); + } + // TODO(liujinnan): add constraint_func name branch. + symbolTableMap_.insert({name, symbol}); + return name; +} + +const std::string SymbolicDimMgr::getNextName() { + std::string name; + do { + name = "S" + std::to_string(nextSymbolicIdx_++); + } while (!symbolNameSet_.insert(name).second); + return name; +} + +SymbolicDimMgr::SymbolicDimMgr(ir::ModuleOp m) : m_(m), symbolTable_(m_) {} + +SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) { + ::ir::Builder builder = ::ir::Builder(m_.ir_context(), m_.block()); + ir::dialect::SymbolicDim symbol = builder.Build( + name.empty() ? getNextName() : name); + symbolDimUnionSet_[symbol] = symbol; + symbolTable_.insert(symbol); + return symbol; +} + +SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) { + auto it = constantSymbolicDimMap_.find(val); + if (it == constantSymbolicDimMap_.end()) { + auto name = "C" + std::to_string(val); + it = constantSymbolicDimMap_ + .insert(std::make_pair(val, newSymbolicDim(name))) + .first; + it->second.updateValue(val); + } + return getRootSymbolicDim(it->second); +} + +std::vector SymbolicDimMgr::createSymbolicDimsForRankedValue( + ir::Value value) { + std::vector symbols; + auto dims = value.type().dyn_cast().dims(); + for (int idx = 0; idx < dims.size(); ++idx) { + symbols.push_back( + dims[idx] == -100000 // TODO(zhangbo): value = ShapedType::kDynamic + ? newSymbolicDim() + : newConstantSymbolicDim(dims[idx])); + } + return symbols; +} + +SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) { + SymbolicDim current = symbol; + std::vector path; + while (symbolDimUnionSet_[current] != current) { + path.push_back(current); + current = symbolDimUnionSet_[current]; + } + for (SymbolicDim sym : path) symbolDimUnionSet_[sym] = current; + return current; +} + +bool SymbolicDimMgr::isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhsRoot = getRootSymbolicDim(lhs); + SymbolicDim rhsRoot = getRootSymbolicDim(rhs); + return lhsRoot == rhsRoot; +} + +bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhsRoot = getRootSymbolicDim(lhs); + SymbolicDim rhsRoot = getRootSymbolicDim(rhs); + + if (lhsRoot != rhsRoot) { + if (compareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) { + if (!lhsRoot.merge(rhsRoot)) return false; + symbolDimUnionSet_[rhsRoot] = lhsRoot; + } else { + if (!rhsRoot.merge(lhsRoot)) return false; + symbolDimUnionSet_[lhsRoot] = rhsRoot; + } + } + return true; +} + +} // namespace ir diff --git a/paddle/ir/dialect/shape/utils/shape_utils.h b/paddle/ir/dialect/shape/utils/shape_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1f9736f7d3e8b54acb5428e14013212745897120 --- /dev/null +++ b/paddle/ir/dialect/shape/utils/shape_utils.h @@ -0,0 +1,109 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/utils.h" +#include "paddle/ir/dialect/shape/ir/shape_op.h" + +namespace ir { + +using ir::dialect::SymbolicDim; + +struct SymbolicDimProduct { + std::vector symbols; + int64_t factor = 1; + bool empty() { return factor == 1 && symbols.empty(); } +}; + +inline bool operator==(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols; +} + +inline bool operator!=(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + return !(lhs == rhs); +} + +class SymbolTable { + public: + explicit SymbolTable(ir::Operation* symbolTableOp) + : symbolTableOp_(symbolTableOp) {} + ir::Operation* lookup(const std::string& name) const; + const std::string insert(Operation* symbol); + ir::Operation* getOp() const { return symbolTableOp_; } + + private: + ir::Operation* symbolTableOp_; + std::unordered_map symbolTableMap_; +}; + +struct SymDimHasher { + size_t operator()(const ir::dialect::SymbolicDim& symbol) const noexcept { + return std::hash{}(symbol.operation()); + } +}; + +struct SymProductHasher { + size_t operator()(const ir::SymbolicDimProduct& symProd) const noexcept { + size_t hash = std::hash{}(symProd.symbols.size()); + for (auto& symbol : symProd.symbols) { + hash = hash_combine(hash, SymDimHasher{}(symbol)); // NOLINT + } + hash = hash_combine(hash, std::hash{}(symProd.factor)); + return hash; + } +}; + +class SymbolicDimMgr { + public: + explicit SymbolicDimMgr(ir::ModuleOp m); + SymbolicDim newSymbolicDim(const std::string& name = {}); + SymbolicDim newConstantSymbolicDim(int64_t val); + std::vector createSymbolicDimsForRankedValue(Value value); + SymbolicDim getRootSymbolicDim(SymbolicDim symbol); + bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + SymbolTable& symbolTable() { return symbolTable_; } + bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + + private: + const std::string getNextName(); + + private: + ir::ModuleOp m_; + + SymbolTable symbolTable_; + + int64_t nextSymbolicIdx_ = 0; + + std::unordered_set symbolNameSet_; + + std::unordered_map symbolDimUnionSet_; + + std::unordered_map constantSymbolicDimMap_; + + // productEqualityMap_[A][B] == true : Product[A] == Product[B] + using SymbolicDimProductMap = std::unordered_map< + SymbolicDimProduct, + std::unordered_map, + SymProductHasher>; + SymbolicDimProductMap productEqualityMap_; +}; + +} // namespace ir diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 394320179c3726a0615928a120469bdc8eee2c6f..42c765743a3c73393d60bf029c06c30750bf4a1d 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -39,3 +39,5 @@ collect_srcs( kernel_factory.cc tensor_utils.cc utils/type_info.cc) + +cc_library(ddim SRCS ddim.cc) diff --git a/test/cpp/ir/shape_dialect/assist_struct_test.cc b/test/cpp/ir/shape_dialect/assist_struct_test.cc index ae94fdbda4d11e137cd65cd94bcd9f43b6356cad..2d5216104f5897a980cd6336d7cf429f2cfe138d 100644 --- a/test/cpp/ir/shape_dialect/assist_struct_test.cc +++ b/test/cpp/ir/shape_dialect/assist_struct_test.cc @@ -14,39 +14,132 @@ #include #include +#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/ir/core/block.h" #include "paddle/ir/core/builder.h" +#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/program.h" #include "paddle/ir/dialect/shape/ir/shape_dialect.h" #include "paddle/ir/dialect/shape/ir/shape_op.h" +#include "paddle/ir/dialect/shape/utils/shape_utils.h" TEST(assist_struct_test, symbolic_dim) { ir::IrContext *ctx = ir::IrContext::Instance(); ir::Program program(ctx); ctx->GetOrRegisterDialect(); ir::Builder builder = ir::Builder(ctx, program.block()); - ir::dialect::SymbolicDim sym_dim = builder.Build( + ir::dialect::SymbolicDim symDim = builder.Build( "S0", 10, false, false, false, false); - EXPECT_EQ(sym_dim.getValue(), 10); - EXPECT_EQ(sym_dim.getSymName(), "S0"); - EXPECT_FALSE(sym_dim.getKnownNegativeOne()); - EXPECT_FALSE(sym_dim.getKnownNonSizeOne()); - EXPECT_FALSE(sym_dim.getKnownNonSizeZero()); - EXPECT_FALSE(sym_dim.getKnownNonNegative()); - - sym_dim.updateValue(20); - sym_dim.updateSymName("S1"); - sym_dim.updateKnownNegativeOne(true); - sym_dim.updateKnownNonSizeOne(true); - sym_dim.updateKnownNonSizeZero(true); - sym_dim.updateKnownNonNegative(true); - - EXPECT_EQ(sym_dim.getValue(), 20); - EXPECT_EQ(sym_dim.getSymName(), "S1"); - EXPECT_TRUE(sym_dim.getKnownNegativeOne()); - EXPECT_TRUE(sym_dim.getKnownNonSizeOne()); - EXPECT_TRUE(sym_dim.getKnownNonSizeZero()); - EXPECT_TRUE(sym_dim.getKnownNonNegative()); + ir::dialect::SymbolicDim symDim_ = builder.Build( + "S1", 10, false, false, false, false); + EXPECT_EQ(symDim.getValue(), 10); + EXPECT_EQ(symDim.getSymName(), "S0"); + EXPECT_FALSE(symDim.getKnownNegativeOne()); + EXPECT_FALSE(symDim.getKnownNonSizeOne()); + EXPECT_FALSE(symDim.getKnownNonSizeZero()); + EXPECT_FALSE(symDim.getKnownNonNegative()); + + EXPECT_FALSE(symDim.isDynamic()); + EXPECT_TRUE(symDim.merge(symDim_)); + + symDim.updateValue(20); + symDim.updateSymName("S2"); + symDim.updateKnownNegativeOne(true); + symDim.updateKnownNonSizeOne(true); + symDim.updateKnownNonSizeZero(true); + symDim.updateKnownNonNegative(true); + + EXPECT_FALSE(symDim.merge(symDim_)); + + EXPECT_EQ(symDim.getValue(), 20); + EXPECT_EQ(symDim.getSymName(), "S2"); + EXPECT_TRUE(symDim.getKnownNegativeOne()); + EXPECT_TRUE(symDim.getKnownNonSizeOne()); + EXPECT_TRUE(symDim.getKnownNonSizeZero()); + EXPECT_TRUE(symDim.getKnownNonNegative()); +} + +TEST(assist_struct_test, symbolic_dim_product) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ir::Builder builder = ir::Builder(ctx, program.block()); + ir::dialect::SymbolicDim symDim = builder.Build( + "S0", -100000, false, false, false, false); + ir::SymbolicDimProduct symDimProduct; + ir::SymbolicDimProduct symDimProduct_; + symDimProduct.symbols.push_back(symDim); + symDimProduct.factor *= 10; + EXPECT_EQ(symDimProduct.factor, 10); + EXPECT_NE(symDimProduct, symDimProduct_); + EXPECT_FALSE(symDimProduct.empty()); +} + +TEST(assist_struct_test, symbolic_dim_table) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ir::Builder builder = ir::Builder(ctx, program.block()); + ir::dialect::SymbolicDim symDim = builder.Build( + "S0", 10, false, false, false, false); + + ir::SymbolTable symbolTable(program.module_op()); + EXPECT_EQ(symbolTable.insert(symDim), "S0"); + EXPECT_EQ(symbolTable.lookup("S0")->dyn_cast(), + symDim); + EXPECT_EQ(symbolTable.lookup("S1"), nullptr); + EXPECT_EQ(symbolTable.getOp(), program.module_op()); +} + +TEST(assist_struct_test, symbolic_dim_mgr) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Program program(ctx); + ctx->GetOrRegisterDialect(); + + ir::SymbolicDimMgr symDimMgr(program.module_op()); + ir::dialect::SymbolicDim symDimS0 = symDimMgr.newSymbolicDim(); + ir::dialect::SymbolicDim symDimS1 = symDimMgr.newSymbolicDim(); + ir::dialect::SymbolicDim symDimC10 = symDimMgr.newConstantSymbolicDim(10); + symDimMgr.mapSymbolicDimEqual(symDimS0, symDimS1); + + ir::Attribute attr_value = ir::StrAttribute::get(ctx, "op_attr"); + ir::AttributeMap attr_map; + attr_map.insert(std::pair("op", attr_value)); + std::vector op_inputs = {}; + + ir::Type fp32_dtype = ir::Float32Type::get(ctx); + phi::DDim dims = {-100000, 2}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + ir::Operation *op = + ir::Operation::Create(op_inputs, attr_map, op_output_types, ir::OpInfo()); + ir::Value res = op->result(0); + + std::vector symDimVec = + symDimMgr.createSymbolicDimsForRankedValue(res); + + EXPECT_EQ(symDimS0.getSymName(), "S0"); + EXPECT_EQ(symDimS1.getSymName(), "S1"); + EXPECT_EQ(symDimS1.getValue(), -100000); + EXPECT_EQ(symDimC10.getSymName(), "C10"); + EXPECT_EQ(symDimC10.getValue(), 10); + EXPECT_EQ(symDimVec[0].getSymName(), "S2"); + EXPECT_EQ(symDimVec[1].getSymName(), "C2"); + EXPECT_EQ(symDimMgr.symbolTable() + .lookup("S0") + ->dyn_cast(), + symDimS0); + EXPECT_EQ(symDimMgr.symbolTable() + .lookup("C10") + ->dyn_cast(), + symDimC10); + EXPECT_EQ(symDimMgr.getRootSymbolicDim(symDimS1), symDimS0); + EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS1)); + EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10)); }