未验证 提交 589588f3 编写于 作者: L liuruyan 提交者: GitHub

Add dimOp, tieProductEqualOp. access constraint_func in SymbolTable. Lowing...

Add dimOp, tieProductEqualOp. access constraint_func in SymbolTable. Lowing DenseTensorType. (#56615)

* add symbolicDimProduct & symbolicDimMgr without method shape_constraint related.

* add pd_type.cc to ir_shape CMakeLists.

* add dimOp, tieProductEqualOp. access constraint_func in SymbolTable.

* put DenseTensorType into builtin_type.
上级 39e8b023
......@@ -16,18 +16,6 @@
namespace paddle {
namespace dialect {
const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
const phi::DDim& DenseTensorType::dims() const { return storage()->dims_; }
const phi::DataLayout& DenseTensorType::data_layout() const {
return storage()->layout_;
}
const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; }
const size_t& DenseTensorType::offset() const { return storage()->offset_; }
const ir::Type& SelectedRowsType::dtype() const { return storage()->dtype_; }
const phi::DDim& SelectedRowsType::dims() const { return storage()->dims_; }
......@@ -43,5 +31,4 @@ const size_t& SelectedRowsType::offset() const { return storage()->offset_; }
} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
......@@ -15,30 +15,12 @@
#pragma once
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/type.h"
namespace paddle {
namespace dialect {
///
/// \brief Define built-in parametric types.
///
class DenseTensorType : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
const ir::Type &dtype() const;
const phi::DDim &dims() const;
const phi::DataLayout &data_layout() const;
const phi::LoD &lod() const;
const size_t &offset() const;
};
using DenseTensorType = ir::DenseTensorType;
class SelectedRowsType : public ir::Type {
public:
using Type::Type;
......@@ -59,5 +41,4 @@ class SelectedRowsType : public ir::Type {
} // namespace dialect
} // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
......@@ -16,117 +16,15 @@
#include <type_traits>
#include "paddle/ir/core/builtin_type_storage.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/type_base.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/core/tensor_meta.h"
namespace std {
///
/// \brief Enable hashing std::vector<T> instances.
///
template <typename T>
struct hash<std::vector<T>> {
std::size_t operator()(const std::vector<T>& dim) const {
std::size_t seed = 0;
for (size_t i = 0; i < dim.size(); ++i) {
seed ^= std::hash<T>()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
} // namespace std
namespace paddle {
namespace dialect {
///
/// \brief Define Parametric TypeStorage for DenseTensorType.
///
/// NOTE(zhangbo9674): The derived TypeStorage class needs to implement the
/// following methods: (1)declare ParamKey, (2)define Construction method,
/// (3)define HashValue method, (4)overload operator==.
///
struct DenseTensorTypeStorage : public ir::TypeStorage {
using DataLayout = phi::DataLayout;
using Dim = phi::DDim;
using LoD = std::vector<std::vector<size_t>>;
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey =
std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>;
DenseTensorTypeStorage(const ir::Type& dtype,
const phi::DDim& dims,
const phi::DataLayout& layout,
const phi::LoD& lod,
size_t offset)
: dtype_(dtype),
dims_(dims),
layout_(layout),
lod_(lod),
offset_(offset) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static DenseTensorTypeStorage* Construct(const ParamKey& key) {
return new DenseTensorTypeStorage(std::get<0>(key),
std::get<1>(key),
std::get<2>(key),
std::get<3>(key),
std::get<4>(key));
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) {
std::size_t hash_value = 0;
// hash dtype
hash_value =
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value =
ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key)));
// hash layout
hash_value = ir::hash_combine(
hash_value,
std::hash<std::underlying_type<phi::DataLayout>::type>()(
static_cast<std::underlying_type<phi::DataLayout>::type>(
std::get<2>(key))));
// hash lod
hash_value =
ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key)));
// hash offset
hash_value =
ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
return hash_value;
}
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
}
ParamKey GetAsKey() const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_);
}
///
/// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
/// layout, lod, offset.
///
ir::Type dtype_;
phi::DDim dims_;
phi::DataLayout layout_;
phi::LoD lod_;
size_t offset_;
};
using DenseTensorTypeStorage = ir::DenseTensorTypeStorage;
struct SelectedRowsTypeStorage : public ir::TypeStorage {
using DataLayout = phi::DataLayout;
......
......@@ -3,4 +3,4 @@ set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir")
file(GLOB IR_SRCS "*.cc")
ir_library(ir_core SRCS ${IR_SRCS})
ir_library(ir_core SRCS ${IR_SRCS} DEPS ddim)
......@@ -17,6 +17,21 @@
namespace ir {
std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }
const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const {
return storage()->dims_;
}
const DenseTensorTypeStorage::DataLayout& DenseTensorType::data_layout() const {
return storage()->layout_;
}
const DenseTensorTypeStorage::LoD& DenseTensorType::lod() const {
return storage()->lod_;
}
const size_t& DenseTensorType::offset() const { return storage()->offset_; }
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::UInt8Type)
......@@ -33,3 +48,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::DenseTensorType)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -53,6 +54,23 @@ class IR_API VectorType : public Type {
Type operator[](size_t index) const { return data()[index]; }
};
class DenseTensorType : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
const ir::Type &dtype() const;
const DenseTensorTypeStorage::Dim &dims() const;
const DenseTensorTypeStorage::DataLayout &data_layout() const;
const DenseTensorTypeStorage::LoD &lod() const;
const size_t &offset() const;
};
#define DECLARE_BUILTIN_TYPE(__name) \
class IR_API __name : public Type { \
public: \
......@@ -99,3 +117,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::DenseTensorType)
......@@ -17,15 +17,122 @@
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/type_base.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
namespace std {
///
/// \brief Enable hashing std::vector<T> instances.
///
template <typename T>
struct hash<std::vector<T>> {
std::size_t operator()(const std::vector<T>& dim) const {
std::size_t seed = 0;
for (size_t i = 0; i < dim.size(); ++i) {
seed ^= std::hash<T>()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
} // namespace std
namespace ir {
///
/// \brief Define Parametric TypeStorage for DenseTensorType.
///
/// NOTE(zhangbo9674): The derived TypeStorage class needs to implement the
/// following methods: (1)declare ParamKey, (2)define Construction method,
/// (3)define HashValue method, (4)overload operator==.
///
struct DenseTensorTypeStorage : public ir::TypeStorage {
///
/// \brief Declare ParamKey according to parameter type.
///
using DataLayout = phi::DataLayout;
using Dim = phi::DDim;
using LoD = std::vector<std::vector<size_t>>;
using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
DenseTensorTypeStorage(const ir::Type& dtype,
const Dim& dims,
const DataLayout& layout,
const LoD& lod,
size_t offset)
: dtype_(dtype),
dims_(dims),
layout_(layout),
lod_(lod),
offset_(offset) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static DenseTensorTypeStorage* Construct(const ParamKey& key) {
return new DenseTensorTypeStorage(std::get<0>(key),
std::get<1>(key),
std::get<2>(key),
std::get<3>(key),
std::get<4>(key));
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) {
std::size_t hash_value = 0;
// hash dtype
hash_value =
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value =
ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
// hash layout
hash_value = ir::hash_combine(
hash_value,
std::hash<std::underlying_type<DataLayout>::type>()(
static_cast<std::underlying_type<DataLayout>::type>(
std::get<2>(key))));
// hash lod
hash_value =
ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
// hash offset
hash_value =
ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
return hash_value;
}
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
}
ParamKey GetAsKey() const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_);
}
///
/// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
/// layout, lod, offset.
///
ir::Type dtype_;
Dim dims_;
DataLayout layout_;
LoD lod_;
size_t offset_;
};
struct VectorTypeStorage : public TypeStorage {
using ParamKey = std::vector<Type>;
explicit VectorTypeStorage(const ParamKey &key) {
data_ = reinterpret_cast<Type *>(malloc(key.size() * sizeof(Type)));
memcpy(reinterpret_cast<void *>(data_),
reinterpret_cast<const void *>(key.data()),
explicit VectorTypeStorage(const ParamKey& key) {
data_ = reinterpret_cast<Type*>(malloc(key.size() * sizeof(Type)));
memcpy(reinterpret_cast<void*>(data_),
reinterpret_cast<const void*>(key.data()),
key.size() * sizeof(Type));
size_ = key.size();
}
......@@ -36,14 +143,14 @@ struct VectorTypeStorage : public TypeStorage {
/// \brief Each derived TypeStorage must define a Construc method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static VectorTypeStorage *Construct(const ParamKey &key) {
static VectorTypeStorage* Construct(const ParamKey& key) {
return new VectorTypeStorage(key);
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey &key) {
static std::size_t HashValue(const ParamKey& key) {
std::size_t hash_value = 0;
for (size_t i = 0; i < key.size(); ++i) {
hash_value = hash_combine(hash_value, std::hash<Type>()(key[i]));
......@@ -54,7 +161,7 @@ struct VectorTypeStorage : public TypeStorage {
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey &key) const {
bool operator==(const ParamKey& key) const {
if (key.size() != size_) {
return false;
}
......@@ -72,7 +179,7 @@ struct VectorTypeStorage : public TypeStorage {
/// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
/// layout, lod, offset.
///
Type *data_;
Type* data_;
size_t size_;
};
......
file(GLOB_RECURSE SHAPE_SRCS "*.cc")
ir_library(
ir_shape
SRCS
${SHAPE_SRCS}
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
DEPS
ddim
ir_core)
ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core)
......@@ -22,7 +22,9 @@ ShapeDialect::ShapeDialect(IrContext *context)
initialize();
}
void ShapeDialect::initialize() { RegisterOps<SymbolicDim>(); }
void ShapeDialect::initialize() {
RegisterOps<SymbolicDim, DimOp, TieProductEqualOp>();
}
} // namespace dialect
} // namespace ir
......
......@@ -14,6 +14,7 @@
#include "paddle/ir/dialect/shape/ir/shape_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
namespace ir {
namespace dialect {
......@@ -132,7 +133,65 @@ bool SymbolicDim::merge(SymbolicDim other) {
return true;
}
const char *DimOp::attributes_name[attributes_num] = {"name"}; // NOLINT
void DimOp::Build(Builder &builder,
OperationArgument &argument,
const std::string &name) {
ir::Attribute attr_name =
ir::StrAttribute::get(ir::IrContext::Instance(), name);
argument.AddAttribute("name", attr_name);
argument.output_types.emplace_back(
ir::IndexType::get(ir::IrContext::Instance()));
}
const std::string DimOp::getName() {
return attribute<ir::StrAttribute>("name").AsString();
}
void DimOp::setName(std::string attrName) {
operation()->set_attribute(
"name", ir::StrAttribute::get(ir::IrContext::Instance(), attrName));
}
const char *TieProductEqualOp::attributes_name[attributes_num] = {
"lhs_len", "rhs_len"}; // NOLINT
void TieProductEqualOp::Build(Builder &builder,
OperationArgument &argument,
int64_t lhs_len,
int64_t rhs_len,
const std::vector<ir::OpResult> &inputs) {
ir::Attribute attr_lhs_len =
ir::Int64Attribute::get(ir::IrContext::Instance(), lhs_len);
argument.AddAttribute("lhs_len", attr_lhs_len);
ir::Attribute attr_rhs_len =
ir::Int64Attribute::get(ir::IrContext::Instance(), rhs_len);
argument.AddAttribute("rhs_len", attr_rhs_len);
argument.inputs = inputs;
}
std::vector<ir::Value> TieProductEqualOp::getLhs() {
int64_t lhs_len = attribute<ir::Int64Attribute>("lhs_len").data();
std::vector<ir::Value> res;
for (uint32_t idx = 0; idx < lhs_len; idx++) {
res.push_back(operand_source(idx));
}
return res;
}
std::vector<ir::Value> TieProductEqualOp::getRhs() {
int64_t lhs_len = attribute<ir::Int64Attribute>("lhs_len").data();
int64_t rhs_len = attribute<ir::Int64Attribute>("rhs_len").data();
std::vector<ir::Value> res;
for (uint32_t idx = 0; idx < rhs_len; idx++) {
res.push_back(operand_source(lhs_len + idx));
}
return res;
}
} // namespace dialect
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::DimOp)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::TieProductEqualOp)
......@@ -57,7 +57,45 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
void Verify() {}
};
class IR_API DimOp : public Op<DimOp> {
public:
using Op::Op;
static const char *name() { return "shape.dim"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::string &name);
const std::string getName();
void setName(std::string attrValue);
ir::OpResult out() { return result(0); }
void Verify() {}
};
class IR_API TieProductEqualOp : public Op<TieProductEqualOp> {
public:
using Op::Op;
static const char *name() { return "shape.tie_product_equal"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
// attr operand_segment_sizes
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
int64_t lhs_len,
int64_t rhs_len,
const std::vector<ir::OpResult> &inputs);
std::vector<ir::Value> getLhs();
std::vector<ir::Value> getRhs();
void Verify() {}
};
} // namespace dialect
} // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::DimOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::TieProductEqualOp);
......@@ -18,8 +18,8 @@
namespace ir {
bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) {
if (lhs.size() < 1 || lhs[0] != 'S' && lhs[0] != 'C') return lhs < rhs;
if (rhs.size() < 1 || rhs[0] != 'S' && rhs[0] != 'C') return lhs < rhs;
if (lhs.size() < 1 || (lhs[0] != 'S' && lhs[0] != 'C')) return lhs < rhs;
if (rhs.size() < 1 || (rhs[0] != 'S' && rhs[0] != 'C')) return lhs < rhs;
int64_t lhsIdx = 0, rhsIdx = 0;
try {
lhsIdx = stol(lhs.substr(1));
......@@ -30,18 +30,19 @@ bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) {
return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx);
}
ir::Operation* SymbolTable::lookup(const std::string& name) const {
auto it = symbolTableMap_.find(name);
return it != symbolTableMap_.end() ? it->second : nullptr;
}
const std::string SymbolTable::insert(ir::Operation* symbol) {
std::string name;
if (symbol->HasAttribute("sym_name")) {
if (symbol->name() == "shape.SymbolicDim") {
name = symbol->dyn_cast<SymbolicDim>().getSymName();
symbolTableMap_.insert({name, symbol});
}
// TODO(liujinnan): add constraint_func name branch.
symbolTableMap_.insert({name, symbol});
// TODO(liujinnan): add more constraint_func name branch.
if (symbol->name() == "shape.tie_product_equal") {
name = "tie_product_equal";
symbolFuncMap_[name].emplace_back(symbol);
}
return name;
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <functional>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include "paddle/ir/core/builtin_op.h"
......@@ -29,29 +30,51 @@ struct SymbolicDimProduct {
std::vector<SymbolicDim> symbols;
int64_t factor = 1;
bool empty() { return factor == 1 && symbols.empty(); }
};
inline bool operator==(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols;
}
friend inline bool operator==(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols;
}
inline bool operator!=(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return !(lhs == rhs);
}
friend inline bool operator!=(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return !(lhs == rhs);
}
};
class SymbolTable {
public:
explicit SymbolTable(ir::Operation* symbolTableOp)
: symbolTableOp_(symbolTableOp) {}
ir::Operation* lookup(const std::string& name) const;
template <typename T>
typename std::enable_if<std::is_same<T, SymbolicDim>::value,
SymbolicDim>::type
lookup(const std::string& name) const {
auto it = symbolTableMap_.find(name);
return it != symbolTableMap_.end() ? it->second->dyn_cast<SymbolicDim>()
: SymbolicDim(nullptr);
}
template <typename T>
typename std::enable_if<!std::is_same<T, SymbolicDim>::value,
std::vector<T>>::type
lookup(const std::string& name) const {
std::vector<T> res;
auto it = symbolFuncMap_.find(name);
if (it != symbolFuncMap_.end()) {
for (auto& p : it->second) {
res.push_back(p->dyn_cast<T>());
}
}
return res;
}
const std::string insert(Operation* symbol);
ir::Operation* getOp() const { return symbolTableOp_; }
private:
ir::Operation* symbolTableOp_;
std::unordered_map<std::string, ir::Operation*> symbolTableMap_;
std::unordered_map<std::string, std::vector<ir::Operation*>> symbolFuncMap_;
};
struct SymDimHasher {
......
......@@ -88,10 +88,9 @@ TEST(assist_struct_test, symbolic_dim_table) {
ir::SymbolTable symbolTable(program.module_op());
EXPECT_EQ(symbolTable.insert(symDim), "S0");
EXPECT_EQ(symbolTable.lookup("S0")->dyn_cast<ir::dialect::SymbolicDim>(),
symDim);
EXPECT_EQ(symbolTable.lookup("S1"), nullptr);
EXPECT_EQ(symbolTable.lookup<ir::dialect::SymbolicDim>("S0"), symDim);
EXPECT_EQ(symbolTable.getOp(), program.module_op());
EXPECT_FALSE(symbolTable.lookup<ir::dialect::SymbolicDim>("S1"));
}
TEST(assist_struct_test, symbolic_dim_mgr) {
......@@ -133,15 +132,63 @@ TEST(assist_struct_test, symbolic_dim_mgr) {
EXPECT_EQ(symDimC10.getValue(), 10);
EXPECT_EQ(symDimVec[0].getSymName(), "S2");
EXPECT_EQ(symDimVec[1].getSymName(), "C2");
EXPECT_EQ(symDimMgr.symbolTable()
.lookup("S0")
->dyn_cast<ir::dialect::SymbolicDim>(),
EXPECT_EQ(symDimMgr.symbolTable().lookup<ir::dialect::SymbolicDim>("S0"),
symDimS0);
EXPECT_EQ(symDimMgr.symbolTable()
.lookup("C10")
->dyn_cast<ir::dialect::SymbolicDim>(),
EXPECT_EQ(symDimMgr.symbolTable().lookup<ir::dialect::SymbolicDim>("C10"),
symDimC10);
EXPECT_EQ(symDimMgr.getRootSymbolicDim(symDimS1), symDimS0);
EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS1));
EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10));
}
TEST(assist_struct_test, dim) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::dialect::DimOp dimOp = builder.Build<ir::dialect::DimOp>("S0");
ir::OpResult res = dimOp.out();
EXPECT_EQ(dimOp.getName(), "S0");
dimOp.setName("S1");
EXPECT_EQ(dimOp.getName(), "S1");
EXPECT_EQ(res.GetDefiningOp(), dimOp.operation());
EXPECT_EQ(res.type(), ir::IndexType::get(ctx));
}
TEST(assist_struct_test, tie_product_equal) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::SymbolTable symbolTable(program.module_op());
ir::OpResult dimOp0 = builder.Build<ir::dialect::DimOp>("S0").out();
ir::OpResult dimOp1 = builder.Build<ir::dialect::DimOp>("S1").out();
ir::OpResult dimOp2 = builder.Build<ir::dialect::DimOp>("S2").out();
ir::OpResult dimOp3 = builder.Build<ir::dialect::DimOp>("S3").out();
ir::OpResult dimOp4 = builder.Build<ir::dialect::DimOp>("S4").out();
ir::dialect::TieProductEqualOp tie_product_equal =
builder.Build<ir::dialect::TieProductEqualOp>(
2,
3,
std::vector<ir::OpResult>{dimOp0, dimOp1, dimOp2, dimOp3, dimOp4});
std::vector<ir::Value> lhs = tie_product_equal.getLhs();
std::vector<ir::Value> rhs = tie_product_equal.getRhs();
std::vector<ir::Value> lhs_ref{dimOp0, dimOp1};
std::vector<ir::Value> rhs_ref{dimOp2, dimOp3, dimOp4};
EXPECT_EQ(symbolTable.insert(tie_product_equal), "tie_product_equal");
EXPECT_EQ(
symbolTable.lookup<ir::dialect::TieProductEqualOp>("tie_product_equal")
.size(),
static_cast<size_t>(1));
EXPECT_EQ(symbolTable.lookup<ir::dialect::TieProductEqualOp>(
"tie_product_equal")[0],
tie_product_equal);
EXPECT_EQ(lhs, lhs_ref);
EXPECT_EQ(rhs, rhs_ref);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册