未验证 提交 5c10be4f 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Add types and attributes to builtin and pd dialect (#53953)

* add types and attributes

* remove some const_cast

* refine code
上级 fa08a514
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/dialect/pd_attribute.h"
namespace paddle {
namespace dialect {
phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); }
phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); }
phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); }
phi::Place PlaceAttribute::data() const { return storage()->GetAsKey(); }
phi::DataLayout DataLayoutAttribute::data() const {
return storage()->GetAsKey();
}
} // namespace dialect
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/dialect/pd_attribute_storage.h"
#include "paddle/ir/attribute.h"
namespace paddle {
namespace dialect {
#define GET_PD_DIALECT_ATTRIBUTE_LIST \
IntArrayAttribute, ScalarAttribute, DataTypeAttribute, PlaceAttribute, \
DataLayoutAttribute
class IntArrayAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(IntArrayAttribute,
IntArrayAttributeStorage);
bool operator<(const IntArrayAttribute &right) const {
return storage() < right.storage();
}
phi::IntArray data() const;
};
class ScalarAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ScalarAttribute, ScalarAttributeStorage);
bool operator<(const ScalarAttribute &right) const {
return storage() < right.storage();
}
phi::Scalar data() const;
};
class DataTypeAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DataTypeAttribute,
DataTypeAttributeStorage);
bool operator<(const DataTypeAttribute &right) const {
return storage() < right.storage();
}
phi::DataType data() const;
};
class PlaceAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PlaceAttribute, PlaceAttributeStorage);
bool operator<(const PlaceAttribute &right) const {
return storage() < right.storage();
}
phi::Place data() const;
};
class DataLayoutAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DataLayoutAttribute,
DataLayoutAttributeStorage);
bool operator<(const DataLayoutAttribute &right) const {
return storage() < right.storage();
}
phi::DataLayout data() const;
};
} // namespace dialect
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/attribute.h"
#include "paddle/ir/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
namespace paddle {
namespace dialect {
struct IntArrayAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::IntArray;
explicit IntArrayAttributeStorage(const ParamKey &key) { data_ = key; }
static IntArrayAttributeStorage *Construct(ParamKey key) {
return new IntArrayAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
size_t hash_value = 0;
hash_value =
ir::hash_combine(hash_value, std::hash<bool>()(key.FromTensor()));
for (auto value : key.GetData()) {
hash_value = ir::hash_combine(hash_value, std::hash<int64_t>()(value));
}
return hash_value;
}
bool operator==(const ParamKey &key) const {
return (data_.GetData() == key.GetData()) &&
(data_.FromTensor() == key.FromTensor());
}
ParamKey GetAsKey() const { return ParamKey(data_); }
private:
phi::IntArray data_;
};
struct ScalarAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::Scalar;
explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; }
static ScalarAttributeStorage *Construct(ParamKey key) {
return new ScalarAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return ir::hash_combine(std::hash<std::string>()(key.ToString()),
std::hash<bool>()(key.FromTensor()));
}
bool operator==(const ParamKey &key) const { return data_ == key; }
ParamKey GetAsKey() const { return ParamKey(data_); }
private:
phi::Scalar data_;
};
struct DataTypeAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::DataType;
explicit DataTypeAttributeStorage(const ParamKey &key) { data_ = key; }
static DataTypeAttributeStorage *Construct(ParamKey key) {
return new DataTypeAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<phi::DataType>()(key);
}
bool operator==(const ParamKey &key) const { return data_ == key; }
ParamKey GetAsKey() const { return data_; }
private:
phi::DataType data_;
};
struct PlaceAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::Place;
explicit PlaceAttributeStorage(const ParamKey &key) { data_ = key; }
static PlaceAttributeStorage *Construct(ParamKey key) {
return new PlaceAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) { return key.HashValue(); }
bool operator==(const ParamKey &key) const { return data_ == key; }
ParamKey GetAsKey() const { return data_; }
private:
phi::Place data_;
};
struct DataLayoutAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::DataLayout;
explicit DataLayoutAttributeStorage(const ParamKey &key) { data_ = key; }
static DataLayoutAttributeStorage *Construct(ParamKey key) {
return new DataLayoutAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<phi::DataLayout>()(key);
}
bool operator==(const ParamKey &key) const { return data_ == key; }
ParamKey GetAsKey() const { return data_; }
private:
phi::DataLayout data_;
};
} // namespace dialect
} // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/dialect/pd_dialect.h" #include "paddle/fluid/dialect/pd_dialect.h"
#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/utils.h" #include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
...@@ -89,7 +90,8 @@ PaddleDialect::PaddleDialect(ir::IrContext* context) ...@@ -89,7 +90,8 @@ PaddleDialect::PaddleDialect(ir::IrContext* context)
} }
void PaddleDialect::initialize() { void PaddleDialect::initialize() {
RegisterTypes<GET_PADDLE_TYPE_LIST>(); RegisterTypes<GET_PD_DIALECT_TYPE_LIST>();
RegisterAttributes<GET_PD_DIALECT_ATTRIBUTE_LIST>();
RegisterInterfaces<ParameterConvertInterface>(); RegisterInterfaces<ParameterConvertInterface>();
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
#define GET_PADDLE_TYPE_LIST paddle::dialect::DenseTensorType #define GET_PD_DIALECT_TYPE_LIST paddle::dialect::DenseTensorType
/// ///
/// \brief Define built-in parametric types. /// \brief Define built-in parametric types.
......
...@@ -19,4 +19,18 @@ std::string StrAttribute::data() const { return storage()->GetAsKey(); } ...@@ -19,4 +19,18 @@ std::string StrAttribute::data() const { return storage()->GetAsKey(); }
uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); } uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); }
bool BoolAttribute::data() const { return storage()->GetAsKey(); }
float FloatAttribute::data() const { return storage()->GetAsKey(); }
double DoubleAttribute::data() const { return storage()->GetAsKey(); }
int32_t Int32_tAttribute::data() const { return storage()->GetAsKey(); }
int64_t Int64_tAttribute::data() const { return storage()->GetAsKey(); }
std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey();
}
} // namespace ir } // namespace ir
...@@ -22,9 +22,11 @@ namespace ir { ...@@ -22,9 +22,11 @@ namespace ir {
/// ///
/// \brief All built-in attributes. /// \brief All built-in attributes.
/// ///
#define GET_BUILT_IN_ATTRIBUTE_LIST ir::StrAttribute #define GET_BUILT_IN_ATTRIBUTE_LIST \
StrAttribute, BoolAttribute, FloatAttribute, DoubleAttribute, \
Int32_tAttribute, Int64_tAttribute, ArrayAttribute
class StrAttribute : public ir::Attribute { class StrAttribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
...@@ -39,13 +41,64 @@ class StrAttribute : public ir::Attribute { ...@@ -39,13 +41,64 @@ class StrAttribute : public ir::Attribute {
uint32_t size() const; uint32_t size() const;
}; };
} // namespace ir class BoolAttribute : public Attribute {
public:
using Attribute::Attribute;
namespace std { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(BoolAttribute, BoolAttributeStorage);
template <>
struct hash<ir::StrAttribute> { bool data() const;
std::size_t operator()(const ir::StrAttribute &obj) const { };
return std::hash<const ir::StrAttribute::Storage *>()(obj.storage());
} class FloatAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(FloatAttribute, FloatAttributeStorage);
float data() const;
};
class DoubleAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DoubleAttribute, DoubleAttributeStorage);
double data() const;
};
class Int32_tAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32_tAttribute, Int32_tAttributeStorage);
int32_t data() const;
};
class Int64_tAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64_tAttribute, Int64_tAttributeStorage);
int64_t data() const;
};
class ArrayAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage);
std::vector<Attribute> data() const;
size_t size() const { return data().size(); }
bool empty() const { return data().empty(); }
Attribute operator[](size_t index) const { return data()[index]; }
}; };
} // namespace std
} // namespace ir
...@@ -19,17 +19,41 @@ ...@@ -19,17 +19,41 @@
#include <type_traits> #include <type_traits>
#include "paddle/ir/attribute.h" #include "paddle/ir/attribute.h"
#include "paddle/ir/utils.h"
namespace ir { namespace ir {
#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \
struct concrete_storage : public ir::AttributeStorage { \
using ParamKey = bool; \
\
explicit concrete_storage(const ParamKey &key) { data_ = key; } \
\
static concrete_storage *Construct(ParamKey key) { \
return new concrete_storage(key); \
} \
\
static std::size_t HashValue(const ParamKey &key) { \
return std::hash<base_type>()(key); \
} \
\
bool operator==(const ParamKey &key) const { return data_ == key; } \
\
ParamKey GetAsKey() const { return data_; } \
\
private: \
ParamKey data_; \
};
/// ///
/// \brief Define Parameteric AttributeStorage for StrAttribute. /// \brief Define Parameteric AttributeStorage for StrAttribute.
/// ///
struct StrAttributeStorage : public ir::AttributeStorage { struct StrAttributeStorage : public AttributeStorage {
using ParamKey = std::string; using ParamKey = std::string;
explicit StrAttributeStorage(const ParamKey &key) { explicit StrAttributeStorage(const ParamKey &key) {
data_ = reinterpret_cast<char *>(malloc(key.size())); data_ = reinterpret_cast<char *>(malloc(key.size()));
memcpy(data_, const_cast<char *>(key.c_str()), key.size()); memcpy(data_, key.c_str(), key.size());
size_ = key.size(); size_ = key.size();
} }
...@@ -44,7 +68,7 @@ struct StrAttributeStorage : public ir::AttributeStorage { ...@@ -44,7 +68,7 @@ struct StrAttributeStorage : public ir::AttributeStorage {
} }
bool operator==(const ParamKey &key) const { bool operator==(const ParamKey &key) const {
return std::equal(data_, data_ + size_, const_cast<char *>(key.c_str())); return std::equal(data_, data_ + size_, key.c_str());
} }
ParamKey GetAsKey() const { return ParamKey(data_, size_); } ParamKey GetAsKey() const { return ParamKey(data_, size_); }
...@@ -54,4 +78,55 @@ struct StrAttributeStorage : public ir::AttributeStorage { ...@@ -54,4 +78,55 @@ struct StrAttributeStorage : public ir::AttributeStorage {
uint32_t size_; uint32_t size_;
}; };
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t);
struct ArrayAttributeStorage : public AttributeStorage {
using ParamKey = std::vector<Attribute>;
explicit ArrayAttributeStorage(const ParamKey &key) {
data_ =
reinterpret_cast<Attribute *>(malloc(sizeof(Attribute) * key.size()));
memcpy(reinterpret_cast<void *>(data_),
reinterpret_cast<const void *>(key.data()),
sizeof(Attribute) * key.size());
length_ = key.size();
}
~ArrayAttributeStorage() { free(reinterpret_cast<void *>(data_)); }
static ArrayAttributeStorage *Construct(ParamKey key) {
return new ArrayAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
std::size_t hash_value = 0;
for (size_t i = 0; i < key.size(); ++i) {
hash_value = hash_combine(hash_value, std::hash<Attribute>()(key[i]));
}
return hash_value;
}
bool operator==(const ParamKey &key) const {
if (key.size() != length_) {
return false;
}
for (size_t i = 0; i < length_; ++i) {
if (data_[i] != key[i]) {
return false;
}
}
return true;
}
ParamKey GetAsKey() const { return ParamKey(data_, data_ + length_); }
private:
Attribute *data_ = nullptr;
size_t length_ = 0;
};
} // namespace ir } // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builtin_type.h"
namespace ir {
std::vector<Type> VectorType::data() const { return storage()->GetAsKey(); }
} // namespace ir
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/ir/builtin_type_storage.h"
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
namespace ir { namespace ir {
...@@ -22,9 +23,9 @@ namespace ir { ...@@ -22,9 +23,9 @@ namespace ir {
/// The built-in Dialect will use this macro to quickly register all built-in /// The built-in Dialect will use this macro to quickly register all built-in
/// types. /// types.
/// ///
#define GET_BUILT_IN_TYPE_LIST \ #define GET_BUILT_IN_TYPE_LIST \
ir::Float16Type, ir::Float32Type, ir::Float64Type, ir::Int16Type, \ BFloat16Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, \
ir::Int32Type, ir::Int64Type Int32Type, Int64Type, BoolType, VectorType
/// ///
/// \brief Define built-in parameterless types. Please add the necessary /// \brief Define built-in parameterless types. Please add the necessary
...@@ -42,58 +43,96 @@ namespace ir { ...@@ -42,58 +43,96 @@ namespace ir {
/// Type fp32 = Float32Type::get(ctx); /// Type fp32 = Float32Type::get(ctx);
/// \endcode /// \endcode
/// ///
class Float16Type : public ir::Type { class BFloat16Type : public Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Float16Type, ir::TypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(BFloat16Type, TypeStorage);
};
class Float16Type : public Type {
public:
using Type::Type;
static Float16Type get(ir::IrContext *context); DECLARE_TYPE_UTILITY_FUNCTOR(Float16Type, TypeStorage);
static Float16Type get(IrContext *context);
}; };
class Float32Type : public ir::Type { class Float32Type : public Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Float32Type, ir::TypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(Float32Type, TypeStorage);
static Float32Type get(ir::IrContext *context); static Float32Type get(IrContext *context);
}; };
class Float64Type : public ir::Type { class Float64Type : public Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Float64Type, ir::TypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(Float64Type, TypeStorage);
static Float64Type get(ir::IrContext *context); static Float64Type get(IrContext *context);
}; };
class Int16Type : public ir::Type { class Int8Type : public Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int16Type, ir::TypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(Int8Type, TypeStorage);
};
static Int16Type get(ir::IrContext *context); class Int16Type : public Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int16Type, TypeStorage);
static Int16Type get(IrContext *context);
}; };
class Int32Type : public ir::Type { class Int32Type : public Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int32Type, ir::TypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(Int32Type, TypeStorage);
static Int32Type get(ir::IrContext *context); static Int32Type get(IrContext *context);
}; };
class Int64Type : public ir::Type { class Int64Type : public Type {
public: public:
using Type::Type; using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(Int64Type, ir::TypeStorage); DECLARE_TYPE_UTILITY_FUNCTOR(Int64Type, TypeStorage);
static Int64Type get(IrContext *context);
};
class BoolType : public Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(BoolType, TypeStorage);
static BoolType get(IrContext *context);
};
class VectorType : public Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(VectorType, VectorTypeStorage);
std::vector<Type> data() const;
size_t size() const { return data().size(); }
bool empty() const { return data().empty(); }
static Int64Type get(ir::IrContext *context); Type operator[](size_t index) const { return data()[index]; }
}; };
} // namespace ir } // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/type.h"
#include "paddle/ir/utils.h"
namespace ir {
struct VectorTypeStorage : public TypeStorage {
using ParamKey = std::vector<Type>;
explicit VectorTypeStorage(const ParamKey &key) {
data_ = reinterpret_cast<Type *>(malloc(key.size() * sizeof(Type)));
memcpy(reinterpret_cast<void *>(data_),
reinterpret_cast<const void *>(key.data()),
key.size() * sizeof(Type));
size_ = key.size();
}
~VectorTypeStorage() { free(data_); }
///
/// \brief Each derived TypeStorage must define a Construc method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static VectorTypeStorage *Construct(ParamKey key) {
return new VectorTypeStorage(key);
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey &key) {
std::size_t hash_value = 0;
for (size_t i = 0; i < key.size(); ++i) {
hash_value = hash_combine(hash_value, std::hash<Type>()(key[i]));
}
return hash_value;
}
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey &key) const {
if (key.size() != size_) {
return false;
}
for (size_t i = 0; i < size_; ++i) {
if (data_[i] != key[i]) {
return false;
}
}
return true;
}
ParamKey GetAsKey() const { return ParamKey(data_, data_ + size_); }
///
/// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
/// layout, lod, offset.
///
Type *data_;
size_t size_;
};
} // namespace ir
...@@ -128,6 +128,16 @@ TEST(type_test, built_in_type) { ...@@ -128,6 +128,16 @@ TEST(type_test, built_in_type) {
EXPECT_EQ(fp16_1.isa<ir::Float16Type>(), true); EXPECT_EQ(fp16_1.isa<ir::Float16Type>(), true);
EXPECT_EQ(fp16_1.isa<ir::Float32Type>(), false); EXPECT_EQ(fp16_1.isa<ir::Float32Type>(), false);
EXPECT_EQ(fp16_1.isa<ir::Type>(), true); EXPECT_EQ(fp16_1.isa<ir::Type>(), true);
// Test 3: Test VectorType
std::vector<ir::Type> vec_type = {int32_1, int64_1};
ir::Type vector_type = ir::VectorType::get(ctx, vec_type);
EXPECT_EQ(vector_type.isa<ir::VectorType>(), true);
EXPECT_EQ(vector_type.dyn_cast<ir::VectorType>().size() == 2, true);
EXPECT_EQ(vector_type.dyn_cast<ir::VectorType>()[0].isa<ir::Int32Type>(),
true);
EXPECT_EQ(vector_type.dyn_cast<ir::VectorType>()[1].isa<ir::Int64Type>(),
true);
} }
// Customize a parameterized TypeStorage IntegerTypeStorage. // Customize a parameterized TypeStorage IntegerTypeStorage.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册