From 14ede2b94575c07e23e20e0f8a58ef8d9c1b3459 Mon Sep 17 00:00:00 2001 From: liuruyan <44316842+liuruyan@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:11:43 +0800 Subject: [PATCH] Add constraint related func into SymDimMgr. (#56727) * add symbolicDimProduct & symbolicDimMgr without method shape_constraint related. * split ddim in phi, add a target ddim, used by pd_type. * add pd_type.cc to ir_shape CMakeLists. * add dimOp, tieProductEqualOp. access constraint_func in SymbolTable. * put DenseTensorType into builtin_type. * add constraint related Mgr func. * move to out assert. --- paddle/ir/dialect/shape/ir/shape_op.cc | 1 + paddle/ir/dialect/shape/utils/shape_utils.cc | 298 ++++++++++++++++++ paddle/ir/dialect/shape/utils/shape_utils.h | 24 +- test/cpp/ir/shape_dialect/symbolic_op_test.cc | 174 +++++++++- 4 files changed, 494 insertions(+), 3 deletions(-) diff --git a/paddle/ir/dialect/shape/ir/shape_op.cc b/paddle/ir/dialect/shape/ir/shape_op.cc index 3681aafa365..776503ea269 100644 --- a/paddle/ir/dialect/shape/ir/shape_op.cc +++ b/paddle/ir/dialect/shape/ir/shape_op.cc @@ -112,6 +112,7 @@ bool SymbolicDim::merge(SymbolicDim other) { if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue()) return false; if (isDynamic() && !other.isDynamic()) updateValue(other.getValue()); + if (!isDynamic() && other.isDynamic()) other.updateValue(getValue()); bool knownNonNegativeFlag = getKnownNonNegative() || other.getKnownNonNegative(); diff --git a/paddle/ir/dialect/shape/utils/shape_utils.cc b/paddle/ir/dialect/shape/utils/shape_utils.cc index 182d335f71c..f9d78a63184 100644 --- a/paddle/ir/dialect/shape/utils/shape_utils.cc +++ b/paddle/ir/dialect/shape/utils/shape_utils.cc @@ -46,6 +46,154 @@ const std::string SymbolTable::insert(ir::Operation* symbol) { return name; } +bool SymbolicDimMgr::load() { + for (auto op_it = m_.block()->begin(); op_it != m_.block()->end(); op_it++) { + symbolTable_.insert(*op_it); + SymbolicDim op = (*op_it)->dyn_cast(); + if (!op) continue; + symbolDimUnionSet_[op] = op; + symbolNameSet_.insert(op.getSymName()); + } + return loadShapeConstraintGraph(); +} + +bool SymbolicDimMgr::loadShapeConstraintGraph() { + // TODO(liujinnan): add more constraint function. currently, only support + // tie_product_equal. + auto constraint_vec = + symbolTable_.lookup("tie_product_equal"); + + if (!constraint_vec.size()) return true; + + auto build_sym_product = [&](std::vector range, + SymbolicDimProduct& product) { + for (Value v : range) { + auto definingOp = v.GetDefiningOp(); + if (auto constOp = definingOp->dyn_cast()) { + product.factor *= constOp.value().dyn_cast().data(); + continue; + } else if (auto dimOp = definingOp->dyn_cast()) { + auto sym = symbolTable_.lookup(dimOp.getName()); + if (!sym) return false; + product.symbols.push_back(sym); + continue; + } + return false; + } + return true; + }; + for (auto op : constraint_vec) { + SymbolicDimProduct lhs, rhs; + if (!build_sym_product(op.getLhs(), lhs) || + !build_sym_product(op.getRhs(), rhs) || + !mapSymbolicDimProductEqual(lhs, rhs)) + return false; + } + return true; +} + +int64_t gcd(int64_t m, int64_t n) { + if (!m) return n; + if (!n) return m; + return (m < n) ? gcd(m, n % m) : gcd(m % n, n); +} + +bool SymbolicDimMgr::mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + SymbolicDimProduct newLhs, newRhs; + std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + + // early return for identity case. + if (newLhs == newRhs) return true; + + if (newLhs.factor == newRhs.factor && newLhs.symbols.size() == 1 && + newRhs.symbols.size() == 1) { + return mapSymbolicDimEqual(newLhs.symbols[0], newRhs.symbols[0]); + } else if (newLhs.symbols.size() == 0 && newRhs.symbols.size() == 1 && + newRhs.factor == 1) { + return mapSymbolicDimEqual(newConstantSymbolicDim(newLhs.factor), + newRhs.symbols[0]); + } else if (newRhs.symbols.size() == 0 && newLhs.symbols.size() == 1 && + newLhs.factor == 1) { + return mapSymbolicDimEqual(newConstantSymbolicDim(newRhs.factor), + newLhs.symbols[0]); + } + + productEqualityMap_[newLhs][newRhs] = productEqualityMap_[newRhs][newLhs] = + true; + + productEqualityMapUpdated_ = false; + return true; +} + +std::pair +SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, + const SymbolicDimProduct& y) { + auto lhs = simplifySymbolicDimProduct(x); + auto rhs = simplifySymbolicDimProduct(y); + + SymbolicDimProduct newLhs, newRhs; + int64_t gcdFactor = gcd(std::abs(lhs.factor), std::abs(rhs.factor)); + if (!gcdFactor) return std::make_pair(std::move(newLhs), std::move(newRhs)); + if (std::abs(lhs.factor) < std::abs(rhs.factor)) { + if (lhs.factor < 0) gcdFactor = -gcdFactor; + } else { + if (rhs.factor < 0) gcdFactor = -gcdFactor; + } + + newLhs.factor = lhs.factor / gcdFactor; + newRhs.factor = rhs.factor / gcdFactor; + + std::unordered_map lhsSymbolMap; + std::unordered_map rhsSymbolMap; + for (SymbolicDim op : lhs.symbols) ++lhsSymbolMap[op]; + for (SymbolicDim op : rhs.symbols) ++rhsSymbolMap[op]; + + for (SymbolicDim op : lhs.symbols) { + auto it = rhsSymbolMap.find(op); + if (it != rhsSymbolMap.end() && op.getKnownNonSizeZero()) { + if (--it->second == 0) rhsSymbolMap.erase(it); + continue; + } + newLhs.symbols.push_back(op); + } + + for (SymbolicDim op : rhs.symbols) { + auto it = lhsSymbolMap.find(op); + if (it != lhsSymbolMap.end() && op.getKnownNonSizeZero()) { + if (--it->second == 0) lhsSymbolMap.erase(it); + continue; + } + newRhs.symbols.push_back(op); + } + + if (!newLhs.factor) newLhs.symbols.clear(); + if (!newRhs.factor) newRhs.symbols.clear(); + + return std::make_pair(std::move(newLhs), std::move(newRhs)); +} + +SymbolicDimProduct SymbolicDimMgr::simplifySymbolicDimProduct( + const SymbolicDimProduct& x) { + std::vector copied; + copied.reserve(x.symbols.size()); + for (SymbolicDim op : x.symbols) copied.push_back(getRootSymbolicDim(op)); + + sort(copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { + return compareSymbolicDimNames(lhs.getSymName(), rhs.getSymName()); + }); + SymbolicDimProduct newX; + newX.factor = x.factor; + for (SymbolicDim op : copied) { + if (!op.isDynamic()) { + newX.factor *= op.getValue(); + } else { + newX.symbols.push_back(op); + } + } + return newX; +} + const std::string SymbolicDimMgr::getNextName() { std::string name; do { @@ -123,4 +271,154 @@ bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { return true; } +SymbolicDimProduct* SymbolicDimMgr::symbolicDimProductDivide( + const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { + SymbolicDimProduct newLhs, newRhs; + std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + + if (newLhs.factor == 0 || newRhs.factor == 0) return nullptr; + if (newLhs.factor % newRhs.factor != 0) return nullptr; + if (newLhs.symbols.size() < newRhs.symbols.size()) return nullptr; + + SymbolicDimProduct* result = new SymbolicDimProduct(); + result->factor = newLhs.factor / newRhs.factor; + + std::unordered_map symProcMap; + for (SymbolicDim sym : newRhs.symbols) ++symProcMap[sym]; + + for (SymbolicDim sym : newLhs.symbols) { + auto it = symProcMap.find(sym); + if (it == symProcMap.end()) { + result->symbols.push_back(sym); + continue; + } + if (--it->second == 0) { + symProcMap.erase(it); + continue; + } + } + + if (!symProcMap.empty()) return nullptr; + return result; +} + +bool SymbolicDimMgr::isMultipleOfKnownSymbolicDimProductEqualPair( + const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { + for (auto& pairOutter : productEqualityMap_) { + const SymbolicDimProduct& x = pairOutter.first; + auto factorX = symbolicDimProductDivide(lhs, x); + if (!factorX) continue; + for (auto& pairInner : pairOutter.second) { + if (!pairInner.second) continue; + const SymbolicDimProduct& y = pairInner.first; + auto factorY = symbolicDimProductDivide(rhs, y); + if (!factorY || (*factorX) != (*factorY)) continue; + return true; + } + } + + return false; +} + +bool SymbolicDimMgr::updateProductEqualityMap() { + // early return if nothing is updated. + if (productEqualityMapUpdated_) return true; + + SymbolicDimProductMap newMap; + std::unordered_set productSet; + for (auto& pairOutter : productEqualityMap_) { + const SymbolicDimProduct& x = pairOutter.first; + for (auto& pairInner : pairOutter.second) { + if (!pairInner.second) continue; + const SymbolicDimProduct& y = pairInner.first; + SymbolicDimProduct newX, newY; + std::tie(newX, newY) = simplifySymbolicDimProductPair(x, y); + if (newX == newY) continue; + newMap[newX][newY] = newMap[newY][newX] = true; + productSet.insert(newX); + productSet.insert(newY); + } + } + // hash function of SymbolicDimProduct is expensive, thus we map it to integer + // domain first. + std::unordered_map symProd2Idx; + std::vector idx2SymProd(productSet.size()); + std::vector idx2root(productSet.size()); + for (auto& x : productSet) { + size_t idx = symProd2Idx.size(); + symProd2Idx[&x] = idx; + idx2SymProd[idx] = &x; + idx2root[idx] = idx; + } + + auto getRootIdx = [&](size_t root) { + std::vector path; + while (idx2root[root] != root) { + path.push_back(root); + root = idx2root[root]; + } + for (size_t idx : path) idx2root[idx] = root; + return root; + }; + + for (size_t x = 0; x < symProd2Idx.size(); ++x) { + auto& xProd = *idx2SymProd[x]; + auto& rowMap = newMap[xProd]; + size_t xRoot = getRootIdx(x); + for (size_t y = x; y < symProd2Idx.size(); ++y) { + auto& yProd = *idx2SymProd[y]; + if (!rowMap[yProd]) continue; + idx2root[getRootIdx(y)] = xRoot; + } + } + + for (size_t x = 0; x < symProd2Idx.size(); ++x) + for (size_t y = x; y < symProd2Idx.size(); ++y) { + if (getRootIdx(x) != getRootIdx(y)) continue; + auto& xSymProd = *idx2SymProd[x]; + auto& ySymProd = *idx2SymProd[y]; + + newMap[xSymProd][ySymProd] = newMap[ySymProd][xSymProd] = true; + } + + productEqualityMap_ = std::move(newMap); + + for (auto& x : productSet) + for (auto& y : productSet) { + if (!productEqualityMap_[x][y]) continue; + productEqualityMap_[x][y] = productEqualityMap_[y][x] = false; + if (!isMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { + productEqualityMap_[x][y] = productEqualityMap_[y][x] = true; + } + } + + std::unordered_set toRemove; + for (auto& x : productSet) { + if (std::all_of(productSet.begin(), + productSet.end(), + [&](const SymbolicDimProduct& y) { + return !productEqualityMap_[x][y]; + })) { + toRemove.insert(x); + } + } + + for (auto& x : toRemove) { + productEqualityMap_.erase(x); + } + + productEqualityMapUpdated_ = true; + return true; +} + +bool SymbolicDimMgr::isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + SymbolicDimProduct newLhs, newRhs; + std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + + // early return for identity case. + if (newLhs == newRhs) return true; + IR_ENFORCE(updateProductEqualityMap(), "Update product equality map failed."); + return isMultipleOfKnownSymbolicDimProductEqualPair(newLhs, newRhs); +} } // namespace ir diff --git a/paddle/ir/dialect/shape/utils/shape_utils.h b/paddle/ir/dialect/shape/utils/shape_utils.h index 70f2a16c448..8d5fab1a1c8 100644 --- a/paddle/ir/dialect/shape/utils/shape_utils.h +++ b/paddle/ir/dialect/shape/utils/shape_utils.h @@ -18,6 +18,7 @@ #include #include #include +#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/utils.h" #include "paddle/ir/dialect/shape/ir/shape_op.h" @@ -45,7 +46,6 @@ class SymbolTable { public: explicit SymbolTable(ir::Operation* symbolTableOp) : symbolTableOp_(symbolTableOp) {} - template typename std::enable_if::value, SymbolicDim>::type @@ -97,6 +97,7 @@ struct SymProductHasher { class SymbolicDimMgr { public: explicit SymbolicDimMgr(ir::ModuleOp m); + bool load(); SymbolicDim newSymbolicDim(const std::string& name = {}); SymbolicDim newConstantSymbolicDim(int64_t val); std::vector createSymbolicDimsForRankedValue(Value value); @@ -104,9 +105,28 @@ class SymbolicDimMgr { bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); SymbolTable& symbolTable() { return symbolTable_; } bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + SymbolicDimProduct simplifySymbolicDimProduct(const SymbolicDimProduct& x); + std::pair + simplifySymbolicDimProductPair(const SymbolicDimProduct& x, + const SymbolicDimProduct& y); + SymbolicDimProduct* symbolicDimProductDivide(const SymbolicDimProduct& x, + const SymbolicDimProduct& y); + + bool save(); // TODO(liujinnan): load constraint func + + bool isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs); + bool mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs); private: const std::string getNextName(); + bool updateProductEqualityMap(); + bool isMultipleOfKnownSymbolicDimProductEqualPair( + const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); + bool saveShapeConstraintGraph(); // TODO(liujinnan): load & save + // shape_constraint_func + bool loadShapeConstraintGraph(); private: ir::ModuleOp m_; @@ -127,6 +147,6 @@ class SymbolicDimMgr { std::unordered_map, SymProductHasher>; SymbolicDimProductMap productEqualityMap_; + bool productEqualityMapUpdated_ = true; }; - } // namespace ir diff --git a/test/cpp/ir/shape_dialect/symbolic_op_test.cc b/test/cpp/ir/shape_dialect/symbolic_op_test.cc index 7b0751d17ac..138e5e5b0d8 100644 --- a/test/cpp/ir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/ir/shape_dialect/symbolic_op_test.cc @@ -93,7 +93,10 @@ TEST(assist_struct_test, symbolic_dim_table) { EXPECT_FALSE(symbolTable.lookup("S1")); } -TEST(assist_struct_test, symbolic_dim_mgr) { +TEST(assist_struct_test, symbolic_dim_mgr_simple) { + /******************************************************/ + /* Mgr simple version, only SymbolicDim related func. */ + /******************************************************/ ir::IrContext *ctx = ir::IrContext::Instance(); ir::Program program(ctx); ctx->GetOrRegisterDialect(); @@ -141,6 +144,175 @@ TEST(assist_struct_test, symbolic_dim_mgr) { EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10)); } +TEST(assist_struct_test, symbolic_dim_mgr_complex) { + /***************************************************************/ + /* Mgr with constraintOp, and SymbolicDimProduct related func. */ + /***************************************************************/ + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ir::Builder builder = ir::Builder(ctx, program.block()); + + ir::dialect::SymbolicDim symDimS0 = builder.Build( + "S0", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS1 = builder.Build( + "S1", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS2 = builder.Build( + "S2", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS3 = builder.Build( + "S3", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS4 = builder.Build( + "S4", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS5 = builder.Build( + "S5", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS6 = builder.Build( + "S6", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS7 = builder.Build( + "S7", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS8 = builder.Build( + "S8", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS9 = builder.Build( + "S9", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS10 = builder.Build( + "S10", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS11 = builder.Build( + "S11", -100000, false, false, true, true); + ir::dialect::SymbolicDim symDimS12 = builder.Build( + "S12", -100000, false, false, true, false); + ir::dialect::SymbolicDim symDimC10 = builder.Build( + "C10", 10, true, false, true, true); + ir::dialect::SymbolicDim symDimC20 = builder.Build( + "C20", 20, true, false, true, true); + + ir::OpResult dimOpS0 = builder.Build("S0").out(); + ir::OpResult dimOpS1 = builder.Build("S1").out(); + ir::OpResult dimOpS2 = builder.Build("S2").out(); + ir::OpResult dimOpS3 = builder.Build("S3").out(); + ir::OpResult dimOpS4 = builder.Build("S4").out(); + ir::OpResult dimOpS5 = builder.Build("S5").out(); + ir::OpResult dimOpS6 = builder.Build("S6").out(); + ir::OpResult dimOpS7 = builder.Build("S7").out(); + ir::OpResult dimOpS8 = builder.Build("S8").out(); + ir::OpResult dimOpS9 = builder.Build("S9").out(); + ir::OpResult dimOpS10 = builder.Build("S10").out(); + ir::OpResult dimOpS11 = builder.Build("S11").out(); + ir::OpResult dimOpC10 = builder.Build("C10").out(); + ir::OpResult dimOpC20 = builder.Build("C20").out(); + ir::OpResult constant = + builder + .Build(ir::Int32Attribute::get(ctx, 2), + ir::Int32Type::get(ctx)) + ->result(0); + + // Mark S1 == S2. + builder.Build( + 2, 2, std::vector{constant, dimOpS1, dimOpS2, constant}); + // Mark S0 * S1 == S2 * S3, For check S0 == S3. + builder.Build( + 2, 2, std::vector{dimOpS0, dimOpS1, dimOpS2, dimOpS3}); + // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. + builder.Build( + 3, + 3, + std::vector{ + dimOpS4, dimOpS0, dimOpS1, dimOpS2, dimOpS3, dimOpS5}); + // For check S6 == C10 * C20. + builder.Build( + 1, 2, std::vector{dimOpS6, dimOpC10, dimOpC20}); + // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. + builder.Build( + 3, + 3, + std::vector{ + dimOpC10, dimOpS0, dimOpS1, dimOpS2, dimOpS3, dimOpS7}); + + // Mark S8 * S9 == S10 * S11, for unsimplify product case + builder.Build( + 2, 2, std::vector{dimOpS8, dimOpS9, dimOpS10, dimOpS11}); + + ir::SymbolicDimMgr symDimMgr(program.module_op()); + + symDimMgr.load(); + + // For check indirect equality: S1 * S4 == S2 * S5 + ir::SymbolicDimProduct symDimProductLhs; + ir::SymbolicDimProduct symDimProductRhs; + + symDimProductLhs.symbols.push_back(symDimS1); + symDimProductLhs.symbols.push_back(symDimS4); + + symDimProductRhs.symbols.push_back(symDimS2); + symDimProductRhs.symbols.push_back(symDimS5); + + // For uncompletely simplied product check: S8 * S9 * S12 == S10 * S11 * S12 + ir::SymbolicDimProduct symDimProductLhs_; + ir::SymbolicDimProduct symDimProductRhs_; + + symDimProductLhs_.symbols.push_back(symDimS8); + symDimProductLhs_.symbols.push_back(symDimS9); + symDimProductLhs_.symbols.push_back(symDimS12); + + symDimProductRhs_.symbols.push_back(symDimS10); + symDimProductRhs_.symbols.push_back(symDimS11); + symDimProductRhs_.symbols.push_back(symDimS12); + + // For check simplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = + // 10} + ir::SymbolicDimProduct symDimProductS7; + symDimProductS7.symbols.push_back(symDimS7); + ir::SymbolicDimProduct simplifiedProductS7 = + symDimMgr.simplifySymbolicDimProduct(symDimProductS7); + + // For check simplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z + ir::SymbolicDimProduct symDimProductPairLhs; + ir::SymbolicDimProduct symDimProductPairRhs; + ir::SymbolicDimProduct newLhs, newRhs; + symDimProductPairLhs.symbols.push_back(symDimS4); + symDimProductPairLhs.symbols.push_back(symDimS1); + symDimProductPairLhs.symbols.push_back(symDimS2); + symDimProductPairRhs.symbols.push_back(symDimS1); + symDimProductPairRhs.symbols.push_back(symDimS2); + symDimProductPairRhs.symbols.push_back(symDimS3); + + std::tie(newLhs, newRhs) = symDimMgr.simplifySymbolicDimProductPair( + symDimProductPairLhs, symDimProductPairRhs); + + // For check symbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor + // = 2 Sym = {S4}} + ir::SymbolicDimProduct symDimProductDivLhs; + ir::SymbolicDimProduct symDimProductDivRhs; + symDimProductDivLhs.symbols.push_back(symDimS4); + symDimProductDivLhs.symbols.push_back(symDimS1); + symDimProductDivLhs.symbols.push_back(symDimC20); + symDimProductDivRhs.symbols.push_back(symDimS1); + symDimProductDivRhs.symbols.push_back(symDimC10); + + ir::SymbolicDimProduct *divRes = symDimMgr.symbolicDimProductDivide( + symDimProductDivLhs, symDimProductDivRhs); + + EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS1, symDimS2)); + EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS3)); + EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS4, symDimS5)); + EXPECT_EQ(symDimS6.getValue(), 200); + EXPECT_EQ(symDimMgr.symbolTable().lookup("C20"), + symDimC20); + EXPECT_EQ(symDimS7.getValue(), symDimC10.getValue()); + EXPECT_EQ(simplifiedProductS7.factor, 10); + EXPECT_EQ(simplifiedProductS7.symbols.size(), static_cast(0)); + EXPECT_EQ(newLhs.symbols.size(), static_cast(1)); + EXPECT_EQ(newRhs.symbols.size(), static_cast(1)); + EXPECT_EQ(newLhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS4)); + EXPECT_EQ(newRhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS3)); + EXPECT_EQ(divRes->factor, 2); + EXPECT_EQ(divRes->symbols.size(), static_cast(1)); + EXPECT_EQ(divRes->symbols[0], symDimMgr.getRootSymbolicDim(symDimS4)); + EXPECT_TRUE( + symDimMgr.isSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs)); + EXPECT_TRUE(symDimMgr.isSymbolicDimProductEqual(symDimProductLhs_, + symDimProductRhs_)); +} + TEST(assist_struct_test, dim) { ir::IrContext *ctx = ir::IrContext::Instance(); ir::Program program(ctx); -- GitLab