diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index d07331e76f3af3dfb919b3924f86a143da1ce194..4b25921c8bb3cff3f7fc5ea9373b4c6becdb397f 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -483,6 +483,7 @@ class OpInfoParser: output_type_map = { 'Tensor': 'paddle::dialect::DenseTensorType', 'Tensor[]': 'ir::VectorType', + 'SelectedRows': 'paddle::dialect::SelectedRowsType', } type_list = [] for output_info in self.op_yaml_item['outputs']: diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc index e365758e81ba1a5876137a53bba24223055bb29c..bc78e6ac5a4439852f665d432009f6a1e6f859b6 100644 --- a/paddle/fluid/ir/dialect/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/pd_dialect.cc @@ -92,6 +92,7 @@ PaddleDialect::PaddleDialect(ir::IrContext *context) void PaddleDialect::initialize() { RegisterTypes(); + RegisterTypes(); RegisterAttributeslod_; } const size_t& DenseTensorType::offset() const { return storage()->offset_; } +const ir::Type& SelectedRowsType::dtype() const { return storage()->dtype_; } + +const phi::DDim& SelectedRowsType::dims() const { return storage()->dims_; } + +const phi::DataLayout& SelectedRowsType::data_layout() const { + return storage()->layout_; +} + +const phi::LoD& SelectedRowsType::lod() const { return storage()->lod_; } + +const size_t& SelectedRowsType::offset() const { return storage()->offset_; } + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType) diff --git a/paddle/fluid/ir/dialect/pd_type.h b/paddle/fluid/ir/dialect/pd_type.h index c0c45ebbecc29902a0664218b8c0cc12ae38ad51..249fc018690196133c528d0dc3a7c54e93a46ec5 100644 --- a/paddle/fluid/ir/dialect/pd_type.h +++ b/paddle/fluid/ir/dialect/pd_type.h @@ -39,7 +39,25 @@ class DenseTensorType : public ir::Type { const size_t &offset() const; }; +class SelectedRowsType : public ir::Type { + public: + using Type::Type; + + DECLARE_TYPE_UTILITY_FUNCTOR(SelectedRowsType, SelectedRowsTypeStorage); + + const ir::Type &dtype() const; + + const phi::DDim &dims() const; + + const phi::DataLayout &data_layout() const; + + const phi::LoD &lod() const; + + const size_t &offset() const; +}; + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType) diff --git a/paddle/fluid/ir/dialect/pd_type_storage.h b/paddle/fluid/ir/dialect/pd_type_storage.h index 11aae78fb9b5283532af2803eb6317b99577d3e1..fcfff1db5ae855d76ec69ea99c69c447a8d99fbb 100644 --- a/paddle/fluid/ir/dialect/pd_type_storage.h +++ b/paddle/fluid/ir/dialect/pd_type_storage.h @@ -128,5 +128,86 @@ struct DenseTensorTypeStorage : public ir::TypeStorage { size_t offset_; }; +struct SelectedRowsTypeStorage : public ir::TypeStorage { + using DataLayout = phi::DataLayout; + using Dim = phi::DDim; + using LoD = std::vector>; + /// + /// \brief Declare ParamKey according to parameter type. + /// + using ParamKey = + std::tuple; + + SelectedRowsTypeStorage(const ir::Type& dtype, + const phi::DDim& dims, + const phi::DataLayout& layout, + const phi::LoD& lod, + size_t offset) + : dtype_(dtype), + dims_(dims), + layout_(layout), + lod_(lod), + offset_(offset) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static SelectedRowsTypeStorage* Construct(const ParamKey& key) { + return new SelectedRowsTypeStorage(std::get<0>(key), + std::get<1>(key), + std::get<2>(key), + std::get<3>(key), + std::get<4>(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + std::size_t hash_value = 317; + // hash dtype + hash_value = + ir::hash_combine(hash_value, std::hash()(std::get<0>(key))); + // hash dims + hash_value = + ir::hash_combine(hash_value, std::hash()(std::get<1>(key))); + // hash layout + hash_value = ir::hash_combine( + hash_value, + std::hash::type>()( + static_cast::type>( + std::get<2>(key)))); + // hash lod + hash_value = + ir::hash_combine(hash_value, std::hash()(std::get<3>(key))); + // hash offset + hash_value = + ir::hash_combine(hash_value, std::hash()(std::get<4>(key))); + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key; + } + + ParamKey GetAsKey() const { + return ParamKey(dtype_, dims_, layout_, lod_, offset_); + } + + /// + /// \brief DenseTensorTypeStorage include five parameters: dims, dtype, + /// layout, lod, offset. + /// + ir::Type dtype_; + phi::DDim dims_; + phi::DataLayout layout_; + phi::LoD lod_; + size_t offset_; +}; + } // namespace dialect } // namespace paddle diff --git a/test/cpp/ir/core/type_test.cc b/test/cpp/ir/core/type_test.cc index f25fa3c82428e24eec5de4137bdc7432866db8ba..a748e1d5db88b8bb618208e4f038de260aba57e4 100644 --- a/test/cpp/ir/core/type_test.cc +++ b/test/cpp/ir/core/type_test.cc @@ -15,6 +15,7 @@ #include #include +#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" @@ -229,6 +230,23 @@ TEST(type_test, custom_type_dialect) { EXPECT_EQ(dialect_integer1, dialect_integer2); } +TEST(type_test, pd_dialect) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Type fp32_dtype = ir::Float32Type::get(ctx); + phi::DDim dims = {2, 2}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + paddle::dialect::SelectedRowsType select_rows_dtype = + paddle::dialect::SelectedRowsType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset); + EXPECT_EQ(select_rows_dtype.dtype().isa(), true); + EXPECT_EQ(select_rows_dtype.dims(), dims); + EXPECT_EQ(select_rows_dtype.data_layout(), data_layout); + EXPECT_EQ(select_rows_dtype.lod(), lod); + EXPECT_EQ(select_rows_dtype.offset(), offset); +} + namespace TestNamespace { class TestClass {}; } // namespace TestNamespace