未验证 提交 14ede2b9 编写于 作者: L liuruyan 提交者: GitHub

Add constraint related func into SymDimMgr. (#56727)

* add symbolicDimProduct & symbolicDimMgr without method shape_constraint related.

* split ddim in phi, add a target ddim, used by pd_type.

* add pd_type.cc to ir_shape CMakeLists.

* add dimOp, tieProductEqualOp. access constraint_func in SymbolTable.

* put DenseTensorType into builtin_type.

* add constraint related Mgr func.

* move to out assert.
上级 75c4a24c
...@@ -112,6 +112,7 @@ bool SymbolicDim::merge(SymbolicDim other) { ...@@ -112,6 +112,7 @@ bool SymbolicDim::merge(SymbolicDim other) {
if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue()) if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue())
return false; return false;
if (isDynamic() && !other.isDynamic()) updateValue(other.getValue()); if (isDynamic() && !other.isDynamic()) updateValue(other.getValue());
if (!isDynamic() && other.isDynamic()) other.updateValue(getValue());
bool knownNonNegativeFlag = bool knownNonNegativeFlag =
getKnownNonNegative() || other.getKnownNonNegative(); getKnownNonNegative() || other.getKnownNonNegative();
......
...@@ -46,6 +46,154 @@ const std::string SymbolTable::insert(ir::Operation* symbol) { ...@@ -46,6 +46,154 @@ const std::string SymbolTable::insert(ir::Operation* symbol) {
return name; return name;
} }
bool SymbolicDimMgr::load() {
for (auto op_it = m_.block()->begin(); op_it != m_.block()->end(); op_it++) {
symbolTable_.insert(*op_it);
SymbolicDim op = (*op_it)->dyn_cast<SymbolicDim>();
if (!op) continue;
symbolDimUnionSet_[op] = op;
symbolNameSet_.insert(op.getSymName());
}
return loadShapeConstraintGraph();
}
bool SymbolicDimMgr::loadShapeConstraintGraph() {
// TODO(liujinnan): add more constraint function. currently, only support
// tie_product_equal.
auto constraint_vec =
symbolTable_.lookup<ir::dialect::TieProductEqualOp>("tie_product_equal");
if (!constraint_vec.size()) return true;
auto build_sym_product = [&](std::vector<ir::Value> range,
SymbolicDimProduct& product) {
for (Value v : range) {
auto definingOp = v.GetDefiningOp();
if (auto constOp = definingOp->dyn_cast<ir::ConstantOp>()) {
product.factor *= constOp.value().dyn_cast<ir::Int32Attribute>().data();
continue;
} else if (auto dimOp = definingOp->dyn_cast<ir::dialect::DimOp>()) {
auto sym = symbolTable_.lookup<SymbolicDim>(dimOp.getName());
if (!sym) return false;
product.symbols.push_back(sym);
continue;
}
return false;
}
return true;
};
for (auto op : constraint_vec) {
SymbolicDimProduct lhs, rhs;
if (!build_sym_product(op.getLhs(), lhs) ||
!build_sym_product(op.getRhs(), rhs) ||
!mapSymbolicDimProductEqual(lhs, rhs))
return false;
}
return true;
}
int64_t gcd(int64_t m, int64_t n) {
if (!m) return n;
if (!n) return m;
return (m < n) ? gcd(m, n % m) : gcd(m % n, n);
}
bool SymbolicDimMgr::mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
SymbolicDimProduct newLhs, newRhs;
std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs);
// early return for identity case.
if (newLhs == newRhs) return true;
if (newLhs.factor == newRhs.factor && newLhs.symbols.size() == 1 &&
newRhs.symbols.size() == 1) {
return mapSymbolicDimEqual(newLhs.symbols[0], newRhs.symbols[0]);
} else if (newLhs.symbols.size() == 0 && newRhs.symbols.size() == 1 &&
newRhs.factor == 1) {
return mapSymbolicDimEqual(newConstantSymbolicDim(newLhs.factor),
newRhs.symbols[0]);
} else if (newRhs.symbols.size() == 0 && newLhs.symbols.size() == 1 &&
newLhs.factor == 1) {
return mapSymbolicDimEqual(newConstantSymbolicDim(newRhs.factor),
newLhs.symbols[0]);
}
productEqualityMap_[newLhs][newRhs] = productEqualityMap_[newRhs][newLhs] =
true;
productEqualityMapUpdated_ = false;
return true;
}
std::pair<SymbolicDimProduct, SymbolicDimProduct>
SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x,
const SymbolicDimProduct& y) {
auto lhs = simplifySymbolicDimProduct(x);
auto rhs = simplifySymbolicDimProduct(y);
SymbolicDimProduct newLhs, newRhs;
int64_t gcdFactor = gcd(std::abs(lhs.factor), std::abs(rhs.factor));
if (!gcdFactor) return std::make_pair(std::move(newLhs), std::move(newRhs));
if (std::abs(lhs.factor) < std::abs(rhs.factor)) {
if (lhs.factor < 0) gcdFactor = -gcdFactor;
} else {
if (rhs.factor < 0) gcdFactor = -gcdFactor;
}
newLhs.factor = lhs.factor / gcdFactor;
newRhs.factor = rhs.factor / gcdFactor;
std::unordered_map<SymbolicDim, int, SymDimHasher> lhsSymbolMap;
std::unordered_map<SymbolicDim, int, SymDimHasher> rhsSymbolMap;
for (SymbolicDim op : lhs.symbols) ++lhsSymbolMap[op];
for (SymbolicDim op : rhs.symbols) ++rhsSymbolMap[op];
for (SymbolicDim op : lhs.symbols) {
auto it = rhsSymbolMap.find(op);
if (it != rhsSymbolMap.end() && op.getKnownNonSizeZero()) {
if (--it->second == 0) rhsSymbolMap.erase(it);
continue;
}
newLhs.symbols.push_back(op);
}
for (SymbolicDim op : rhs.symbols) {
auto it = lhsSymbolMap.find(op);
if (it != lhsSymbolMap.end() && op.getKnownNonSizeZero()) {
if (--it->second == 0) lhsSymbolMap.erase(it);
continue;
}
newRhs.symbols.push_back(op);
}
if (!newLhs.factor) newLhs.symbols.clear();
if (!newRhs.factor) newRhs.symbols.clear();
return std::make_pair(std::move(newLhs), std::move(newRhs));
}
SymbolicDimProduct SymbolicDimMgr::simplifySymbolicDimProduct(
const SymbolicDimProduct& x) {
std::vector<SymbolicDim> copied;
copied.reserve(x.symbols.size());
for (SymbolicDim op : x.symbols) copied.push_back(getRootSymbolicDim(op));
sort(copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) {
return compareSymbolicDimNames(lhs.getSymName(), rhs.getSymName());
});
SymbolicDimProduct newX;
newX.factor = x.factor;
for (SymbolicDim op : copied) {
if (!op.isDynamic()) {
newX.factor *= op.getValue();
} else {
newX.symbols.push_back(op);
}
}
return newX;
}
const std::string SymbolicDimMgr::getNextName() { const std::string SymbolicDimMgr::getNextName() {
std::string name; std::string name;
do { do {
...@@ -123,4 +271,154 @@ bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { ...@@ -123,4 +271,154 @@ bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) {
return true; return true;
} }
SymbolicDimProduct* SymbolicDimMgr::symbolicDimProductDivide(
const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) {
SymbolicDimProduct newLhs, newRhs;
std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs);
if (newLhs.factor == 0 || newRhs.factor == 0) return nullptr;
if (newLhs.factor % newRhs.factor != 0) return nullptr;
if (newLhs.symbols.size() < newRhs.symbols.size()) return nullptr;
SymbolicDimProduct* result = new SymbolicDimProduct();
result->factor = newLhs.factor / newRhs.factor;
std::unordered_map<SymbolicDim, int, SymDimHasher> symProcMap;
for (SymbolicDim sym : newRhs.symbols) ++symProcMap[sym];
for (SymbolicDim sym : newLhs.symbols) {
auto it = symProcMap.find(sym);
if (it == symProcMap.end()) {
result->symbols.push_back(sym);
continue;
}
if (--it->second == 0) {
symProcMap.erase(it);
continue;
}
}
if (!symProcMap.empty()) return nullptr;
return result;
}
bool SymbolicDimMgr::isMultipleOfKnownSymbolicDimProductEqualPair(
const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) {
for (auto& pairOutter : productEqualityMap_) {
const SymbolicDimProduct& x = pairOutter.first;
auto factorX = symbolicDimProductDivide(lhs, x);
if (!factorX) continue;
for (auto& pairInner : pairOutter.second) {
if (!pairInner.second) continue;
const SymbolicDimProduct& y = pairInner.first;
auto factorY = symbolicDimProductDivide(rhs, y);
if (!factorY || (*factorX) != (*factorY)) continue;
return true;
}
}
return false;
}
bool SymbolicDimMgr::updateProductEqualityMap() {
// early return if nothing is updated.
if (productEqualityMapUpdated_) return true;
SymbolicDimProductMap newMap;
std::unordered_set<SymbolicDimProduct, SymProductHasher> productSet;
for (auto& pairOutter : productEqualityMap_) {
const SymbolicDimProduct& x = pairOutter.first;
for (auto& pairInner : pairOutter.second) {
if (!pairInner.second) continue;
const SymbolicDimProduct& y = pairInner.first;
SymbolicDimProduct newX, newY;
std::tie(newX, newY) = simplifySymbolicDimProductPair(x, y);
if (newX == newY) continue;
newMap[newX][newY] = newMap[newY][newX] = true;
productSet.insert(newX);
productSet.insert(newY);
}
}
// hash function of SymbolicDimProduct is expensive, thus we map it to integer
// domain first.
std::unordered_map<const SymbolicDimProduct*, size_t> symProd2Idx;
std::vector<const SymbolicDimProduct*> idx2SymProd(productSet.size());
std::vector<size_t> idx2root(productSet.size());
for (auto& x : productSet) {
size_t idx = symProd2Idx.size();
symProd2Idx[&x] = idx;
idx2SymProd[idx] = &x;
idx2root[idx] = idx;
}
auto getRootIdx = [&](size_t root) {
std::vector<size_t> path;
while (idx2root[root] != root) {
path.push_back(root);
root = idx2root[root];
}
for (size_t idx : path) idx2root[idx] = root;
return root;
};
for (size_t x = 0; x < symProd2Idx.size(); ++x) {
auto& xProd = *idx2SymProd[x];
auto& rowMap = newMap[xProd];
size_t xRoot = getRootIdx(x);
for (size_t y = x; y < symProd2Idx.size(); ++y) {
auto& yProd = *idx2SymProd[y];
if (!rowMap[yProd]) continue;
idx2root[getRootIdx(y)] = xRoot;
}
}
for (size_t x = 0; x < symProd2Idx.size(); ++x)
for (size_t y = x; y < symProd2Idx.size(); ++y) {
if (getRootIdx(x) != getRootIdx(y)) continue;
auto& xSymProd = *idx2SymProd[x];
auto& ySymProd = *idx2SymProd[y];
newMap[xSymProd][ySymProd] = newMap[ySymProd][xSymProd] = true;
}
productEqualityMap_ = std::move(newMap);
for (auto& x : productSet)
for (auto& y : productSet) {
if (!productEqualityMap_[x][y]) continue;
productEqualityMap_[x][y] = productEqualityMap_[y][x] = false;
if (!isMultipleOfKnownSymbolicDimProductEqualPair(x, y)) {
productEqualityMap_[x][y] = productEqualityMap_[y][x] = true;
}
}
std::unordered_set<SymbolicDimProduct, SymProductHasher> toRemove;
for (auto& x : productSet) {
if (std::all_of(productSet.begin(),
productSet.end(),
[&](const SymbolicDimProduct& y) {
return !productEqualityMap_[x][y];
})) {
toRemove.insert(x);
}
}
for (auto& x : toRemove) {
productEqualityMap_.erase(x);
}
productEqualityMapUpdated_ = true;
return true;
}
bool SymbolicDimMgr::isSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
SymbolicDimProduct newLhs, newRhs;
std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs);
// early return for identity case.
if (newLhs == newRhs) return true;
IR_ENFORCE(updateProductEqualityMap(), "Update product equality map failed.");
return isMultipleOfKnownSymbolicDimProductEqualPair(newLhs, newRhs);
}
} // namespace ir } // namespace ir
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
#include "paddle/ir/dialect/shape/ir/shape_op.h" #include "paddle/ir/dialect/shape/ir/shape_op.h"
...@@ -45,7 +46,6 @@ class SymbolTable { ...@@ -45,7 +46,6 @@ class SymbolTable {
public: public:
explicit SymbolTable(ir::Operation* symbolTableOp) explicit SymbolTable(ir::Operation* symbolTableOp)
: symbolTableOp_(symbolTableOp) {} : symbolTableOp_(symbolTableOp) {}
template <typename T> template <typename T>
typename std::enable_if<std::is_same<T, SymbolicDim>::value, typename std::enable_if<std::is_same<T, SymbolicDim>::value,
SymbolicDim>::type SymbolicDim>::type
...@@ -97,6 +97,7 @@ struct SymProductHasher { ...@@ -97,6 +97,7 @@ struct SymProductHasher {
class SymbolicDimMgr { class SymbolicDimMgr {
public: public:
explicit SymbolicDimMgr(ir::ModuleOp m); explicit SymbolicDimMgr(ir::ModuleOp m);
bool load();
SymbolicDim newSymbolicDim(const std::string& name = {}); SymbolicDim newSymbolicDim(const std::string& name = {});
SymbolicDim newConstantSymbolicDim(int64_t val); SymbolicDim newConstantSymbolicDim(int64_t val);
std::vector<SymbolicDim> createSymbolicDimsForRankedValue(Value value); std::vector<SymbolicDim> createSymbolicDimsForRankedValue(Value value);
...@@ -104,9 +105,28 @@ class SymbolicDimMgr { ...@@ -104,9 +105,28 @@ class SymbolicDimMgr {
bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
SymbolTable& symbolTable() { return symbolTable_; } SymbolTable& symbolTable() { return symbolTable_; }
bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
SymbolicDimProduct simplifySymbolicDimProduct(const SymbolicDimProduct& x);
std::pair<SymbolicDimProduct, SymbolicDimProduct>
simplifySymbolicDimProductPair(const SymbolicDimProduct& x,
const SymbolicDimProduct& y);
SymbolicDimProduct* symbolicDimProductDivide(const SymbolicDimProduct& x,
const SymbolicDimProduct& y);
bool save(); // TODO(liujinnan): load constraint func
bool isSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs);
bool mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs);
private: private:
const std::string getNextName(); const std::string getNextName();
bool updateProductEqualityMap();
bool isMultipleOfKnownSymbolicDimProductEqualPair(
const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs);
bool saveShapeConstraintGraph(); // TODO(liujinnan): load & save
// shape_constraint_func
bool loadShapeConstraintGraph();
private: private:
ir::ModuleOp m_; ir::ModuleOp m_;
...@@ -127,6 +147,6 @@ class SymbolicDimMgr { ...@@ -127,6 +147,6 @@ class SymbolicDimMgr {
std::unordered_map<SymbolicDimProduct, bool, SymProductHasher>, std::unordered_map<SymbolicDimProduct, bool, SymProductHasher>,
SymProductHasher>; SymProductHasher>;
SymbolicDimProductMap productEqualityMap_; SymbolicDimProductMap productEqualityMap_;
bool productEqualityMapUpdated_ = true;
}; };
} // namespace ir } // namespace ir
...@@ -93,7 +93,10 @@ TEST(assist_struct_test, symbolic_dim_table) { ...@@ -93,7 +93,10 @@ TEST(assist_struct_test, symbolic_dim_table) {
EXPECT_FALSE(symbolTable.lookup<ir::dialect::SymbolicDim>("S1")); EXPECT_FALSE(symbolTable.lookup<ir::dialect::SymbolicDim>("S1"));
} }
TEST(assist_struct_test, symbolic_dim_mgr) { TEST(assist_struct_test, symbolic_dim_mgr_simple) {
/******************************************************/
/* Mgr simple version, only SymbolicDim related func. */
/******************************************************/
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx); ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>(); ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
...@@ -141,6 +144,175 @@ TEST(assist_struct_test, symbolic_dim_mgr) { ...@@ -141,6 +144,175 @@ TEST(assist_struct_test, symbolic_dim_mgr) {
EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10)); EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10));
} }
TEST(assist_struct_test, symbolic_dim_mgr_complex) {
/***************************************************************/
/* Mgr with constraintOp, and SymbolicDimProduct related func. */
/***************************************************************/
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::dialect::SymbolicDim symDimS0 = builder.Build<ir::dialect::SymbolicDim>(
"S0", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS1 = builder.Build<ir::dialect::SymbolicDim>(
"S1", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS2 = builder.Build<ir::dialect::SymbolicDim>(
"S2", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS3 = builder.Build<ir::dialect::SymbolicDim>(
"S3", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS4 = builder.Build<ir::dialect::SymbolicDim>(
"S4", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS5 = builder.Build<ir::dialect::SymbolicDim>(
"S5", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS6 = builder.Build<ir::dialect::SymbolicDim>(
"S6", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS7 = builder.Build<ir::dialect::SymbolicDim>(
"S7", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS8 = builder.Build<ir::dialect::SymbolicDim>(
"S8", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS9 = builder.Build<ir::dialect::SymbolicDim>(
"S9", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS10 = builder.Build<ir::dialect::SymbolicDim>(
"S10", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS11 = builder.Build<ir::dialect::SymbolicDim>(
"S11", -100000, false, false, true, true);
ir::dialect::SymbolicDim symDimS12 = builder.Build<ir::dialect::SymbolicDim>(
"S12", -100000, false, false, true, false);
ir::dialect::SymbolicDim symDimC10 = builder.Build<ir::dialect::SymbolicDim>(
"C10", 10, true, false, true, true);
ir::dialect::SymbolicDim symDimC20 = builder.Build<ir::dialect::SymbolicDim>(
"C20", 20, true, false, true, true);
ir::OpResult dimOpS0 = builder.Build<ir::dialect::DimOp>("S0").out();
ir::OpResult dimOpS1 = builder.Build<ir::dialect::DimOp>("S1").out();
ir::OpResult dimOpS2 = builder.Build<ir::dialect::DimOp>("S2").out();
ir::OpResult dimOpS3 = builder.Build<ir::dialect::DimOp>("S3").out();
ir::OpResult dimOpS4 = builder.Build<ir::dialect::DimOp>("S4").out();
ir::OpResult dimOpS5 = builder.Build<ir::dialect::DimOp>("S5").out();
ir::OpResult dimOpS6 = builder.Build<ir::dialect::DimOp>("S6").out();
ir::OpResult dimOpS7 = builder.Build<ir::dialect::DimOp>("S7").out();
ir::OpResult dimOpS8 = builder.Build<ir::dialect::DimOp>("S8").out();
ir::OpResult dimOpS9 = builder.Build<ir::dialect::DimOp>("S9").out();
ir::OpResult dimOpS10 = builder.Build<ir::dialect::DimOp>("S10").out();
ir::OpResult dimOpS11 = builder.Build<ir::dialect::DimOp>("S11").out();
ir::OpResult dimOpC10 = builder.Build<ir::dialect::DimOp>("C10").out();
ir::OpResult dimOpC20 = builder.Build<ir::dialect::DimOp>("C20").out();
ir::OpResult constant =
builder
.Build<ir::ConstantOp>(ir::Int32Attribute::get(ctx, 2),
ir::Int32Type::get(ctx))
->result(0);
// Mark S1 == S2.
builder.Build<ir::dialect::TieProductEqualOp>(
2, 2, std::vector<ir::OpResult>{constant, dimOpS1, dimOpS2, constant});
// Mark S0 * S1 == S2 * S3, For check S0 == S3.
builder.Build<ir::dialect::TieProductEqualOp>(
2, 2, std::vector<ir::OpResult>{dimOpS0, dimOpS1, dimOpS2, dimOpS3});
// Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5.
builder.Build<ir::dialect::TieProductEqualOp>(
3,
3,
std::vector<ir::OpResult>{
dimOpS4, dimOpS0, dimOpS1, dimOpS2, dimOpS3, dimOpS5});
// For check S6 == C10 * C20.
builder.Build<ir::dialect::TieProductEqualOp>(
1, 2, std::vector<ir::OpResult>{dimOpS6, dimOpC10, dimOpC20});
// Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7.
builder.Build<ir::dialect::TieProductEqualOp>(
3,
3,
std::vector<ir::OpResult>{
dimOpC10, dimOpS0, dimOpS1, dimOpS2, dimOpS3, dimOpS7});
// Mark S8 * S9 == S10 * S11, for unsimplify product case
builder.Build<ir::dialect::TieProductEqualOp>(
2, 2, std::vector<ir::OpResult>{dimOpS8, dimOpS9, dimOpS10, dimOpS11});
ir::SymbolicDimMgr symDimMgr(program.module_op());
symDimMgr.load();
// For check indirect equality: S1 * S4 == S2 * S5
ir::SymbolicDimProduct symDimProductLhs;
ir::SymbolicDimProduct symDimProductRhs;
symDimProductLhs.symbols.push_back(symDimS1);
symDimProductLhs.symbols.push_back(symDimS4);
symDimProductRhs.symbols.push_back(symDimS2);
symDimProductRhs.symbols.push_back(symDimS5);
// For uncompletely simplied product check: S8 * S9 * S12 == S10 * S11 * S12
ir::SymbolicDimProduct symDimProductLhs_;
ir::SymbolicDimProduct symDimProductRhs_;
symDimProductLhs_.symbols.push_back(symDimS8);
symDimProductLhs_.symbols.push_back(symDimS9);
symDimProductLhs_.symbols.push_back(symDimS12);
symDimProductRhs_.symbols.push_back(symDimS10);
symDimProductRhs_.symbols.push_back(symDimS11);
symDimProductRhs_.symbols.push_back(symDimS12);
// For check simplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor =
// 10}
ir::SymbolicDimProduct symDimProductS7;
symDimProductS7.symbols.push_back(symDimS7);
ir::SymbolicDimProduct simplifiedProductS7 =
symDimMgr.simplifySymbolicDimProduct(symDimProductS7);
// For check simplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z
ir::SymbolicDimProduct symDimProductPairLhs;
ir::SymbolicDimProduct symDimProductPairRhs;
ir::SymbolicDimProduct newLhs, newRhs;
symDimProductPairLhs.symbols.push_back(symDimS4);
symDimProductPairLhs.symbols.push_back(symDimS1);
symDimProductPairLhs.symbols.push_back(symDimS2);
symDimProductPairRhs.symbols.push_back(symDimS1);
symDimProductPairRhs.symbols.push_back(symDimS2);
symDimProductPairRhs.symbols.push_back(symDimS3);
std::tie(newLhs, newRhs) = symDimMgr.simplifySymbolicDimProductPair(
symDimProductPairLhs, symDimProductPairRhs);
// For check symbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor
// = 2 Sym = {S4}}
ir::SymbolicDimProduct symDimProductDivLhs;
ir::SymbolicDimProduct symDimProductDivRhs;
symDimProductDivLhs.symbols.push_back(symDimS4);
symDimProductDivLhs.symbols.push_back(symDimS1);
symDimProductDivLhs.symbols.push_back(symDimC20);
symDimProductDivRhs.symbols.push_back(symDimS1);
symDimProductDivRhs.symbols.push_back(symDimC10);
ir::SymbolicDimProduct *divRes = symDimMgr.symbolicDimProductDivide(
symDimProductDivLhs, symDimProductDivRhs);
EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS1, symDimS2));
EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS3));
EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS4, symDimS5));
EXPECT_EQ(symDimS6.getValue(), 200);
EXPECT_EQ(symDimMgr.symbolTable().lookup<ir::dialect::SymbolicDim>("C20"),
symDimC20);
EXPECT_EQ(symDimS7.getValue(), symDimC10.getValue());
EXPECT_EQ(simplifiedProductS7.factor, 10);
EXPECT_EQ(simplifiedProductS7.symbols.size(), static_cast<size_t>(0));
EXPECT_EQ(newLhs.symbols.size(), static_cast<size_t>(1));
EXPECT_EQ(newRhs.symbols.size(), static_cast<size_t>(1));
EXPECT_EQ(newLhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS4));
EXPECT_EQ(newRhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS3));
EXPECT_EQ(divRes->factor, 2);
EXPECT_EQ(divRes->symbols.size(), static_cast<size_t>(1));
EXPECT_EQ(divRes->symbols[0], symDimMgr.getRootSymbolicDim(symDimS4));
EXPECT_TRUE(
symDimMgr.isSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs));
EXPECT_TRUE(symDimMgr.isSymbolicDimProductEqual(symDimProductLhs_,
symDimProductRhs_));
}
TEST(assist_struct_test, dim) { TEST(assist_struct_test, dim) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx); ir::Program program(ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册