未验证 提交 c6bd9fb8 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support SelectRowsType (#55041)

* add selectrows

* fix bug

* add ut

* refine code

* refine code
上级 af96d1e8
......@@ -483,6 +483,7 @@ class OpInfoParser:
output_type_map = {
'Tensor': 'paddle::dialect::DenseTensorType',
'Tensor[]': 'ir::VectorType<paddle::dialect::DenseTensorType>',
'SelectedRows': 'paddle::dialect::SelectedRowsType',
}
type_list = []
for output_info in self.op_yaml_item['outputs']:
......
......@@ -92,6 +92,7 @@ PaddleDialect::PaddleDialect(ir::IrContext *context)
void PaddleDialect::initialize() {
RegisterTypes<paddle::dialect::DenseTensorType>();
RegisterTypes<paddle::dialect::SelectedRowsType>();
RegisterAttributes<paddle::dialect::IntArrayAttribute,
paddle::dialect::DataTypeAttribute,
......
......@@ -161,3 +161,45 @@
- {typename: 'Tensor', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
- name: embedding_grad_sparse
inputs:
- typename: Tensor
name: x
optional: false
no_need_buffer: false
data_transform: {}
- typename: Tensor
name: weight
optional: false
no_need_buffer: false
data_transform: {}
- typename: Tensor
name: out_grad
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: int64_t, name: padding_idx, default_value: '-1'}
- {typename: bool, name: sparse, default_value: 'false'}
outputs:
- {typename: SelectedRows, name: weight_grad, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: UnchangedInferMeta
param: [weight]
kernel:
func: [embedding_grad_sparse]
param: [x, weight, out_grad, padding_idx, sparse]
backend: null
layout: null
data_type:
ordered: false
candidates: [weight]
to_complex_flag: [false]
dispatch: {embedding_grad_sparse: null}
force_backend: null
inplace: null
view: null
backward: null
......@@ -28,7 +28,20 @@ const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; }
const size_t& DenseTensorType::offset() const { return storage()->offset_; }
// SelectedRowsType accessors: each forwards to the corresponding field of the
// SelectedRowsTypeStorage instance backing this type (see storage()).
const ir::Type& SelectedRowsType::dtype() const { return storage()->dtype_; }
const phi::DDim& SelectedRowsType::dims() const { return storage()->dims_; }
const phi::DataLayout& SelectedRowsType::data_layout() const {
return storage()->layout_;
}
const phi::LoD& SelectedRowsType::lod() const { return storage()->lod_; }
const size_t& SelectedRowsType::offset() const { return storage()->offset_; }
} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
......@@ -39,7 +39,25 @@ class DenseTensorType : public ir::Type {
const size_t &offset() const;
};
///
/// \brief IR type representing a SelectedRows value. Mirrors the accessor
/// surface of DenseTensorType (dtype/dims/layout/lod/offset); parameters are
/// stored in SelectedRowsTypeStorage and the getters are defined out-of-line.
///
class SelectedRowsType : public ir::Type {
public:
using Type::Type;
// Declares the standard get()/classof()/storage() helpers that bind this
// type to SelectedRowsTypeStorage.
DECLARE_TYPE_UTILITY_FUNCTOR(SelectedRowsType, SelectedRowsTypeStorage);
const ir::Type &dtype() const;
const phi::DDim &dims() const;
const phi::DataLayout &data_layout() const;
const phi::LoD &lod() const;
const size_t &offset() const;
};
} // namespace dialect
} // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorType)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
......@@ -128,5 +128,86 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
size_t offset_;
};
///
/// \brief Storage for SelectedRowsType. Holds the five parameters that
/// uniquely identify a SelectedRowsType instance and implements the
/// Construct/HashValue/operator== contract required by the IR StorageManager
/// so types can be uniqued per-context.
///
struct SelectedRowsTypeStorage : public ir::TypeStorage {
using DataLayout = phi::DataLayout;
using Dim = phi::DDim;
using LoD = std::vector<std::vector<size_t>>;
///
/// \brief Declare ParamKey according to parameter type.
///
// NOTE(review): the aliases above are declared but ParamKey spells out the
// fully-qualified types instead — presumably for symmetry with
// DenseTensorTypeStorage; confirm before removing them.
using ParamKey =
std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>;
SelectedRowsTypeStorage(const ir::Type& dtype,
const phi::DDim& dims,
const phi::DataLayout& layout,
const phi::LoD& lod,
size_t offset)
: dtype_(dtype),
dims_(dims),
layout_(layout),
lod_(lod),
offset_(offset) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static SelectedRowsTypeStorage* Construct(const ParamKey& key) {
return new SelectedRowsTypeStorage(std::get<0>(key),
std::get<1>(key),
std::get<2>(key),
std::get<3>(key),
std::get<4>(key));
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
/// Combines the hashes of all five parameters, starting from an arbitrary
/// seed (317); must stay consistent with operator== below.
///
static std::size_t HashValue(const ParamKey& key) {
std::size_t hash_value = 317;
// hash dtype
hash_value =
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value =
ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key)));
// hash layout (enums are hashed via their underlying integer type, since
// std::hash has no enum specialization pre-C++14 conventions used here)
hash_value = ir::hash_combine(
hash_value,
std::hash<std::underlying_type<phi::DataLayout>::type>()(
static_cast<std::underlying_type<phi::DataLayout>::type>(
std::get<2>(key))));
// hash lod
hash_value =
ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key)));
// hash offset
hash_value =
ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
return hash_value;
}
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_) == key;
}
// Re-packs the stored fields into a ParamKey (same order as ParamKey).
ParamKey GetAsKey() const {
return ParamKey(dtype_, dims_, layout_, lod_, offset_);
}
///
/// \brief SelectedRowsTypeStorage includes five parameters: dtype, dims,
/// layout, lod, offset.
///
ir::Type dtype_;
phi::DDim dims_;
phi::DataLayout layout_;
phi::LoD lod_;
size_t offset_;
};
} // namespace dialect
} // namespace paddle
......@@ -15,6 +15,7 @@
#include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
......@@ -229,6 +230,23 @@ TEST(type_test, custom_type_dialect) {
EXPECT_EQ(dialect_integer1, dialect_integer2);
}
// Verifies that paddle::dialect::SelectedRowsType can be built through the
// IR type uniquer and that every stored parameter (dtype, dims, layout, lod,
// offset) round-trips unchanged through the accessors.
TEST(type_test, pd_dialect) {
  ir::IrContext *ctx = ir::IrContext::Instance();
  ir::Type fp32_dtype = ir::Float32Type::get(ctx);
  phi::DDim dims = {2, 2};
  phi::DataLayout data_layout = phi::DataLayout::NCHW;
  phi::LoD lod = {{0, 1, 2}};
  size_t offset = 0;
  paddle::dialect::SelectedRowsType select_rows_dtype =
      paddle::dialect::SelectedRowsType::get(
          ctx, fp32_dtype, dims, data_layout, lod, offset);
  // EXPECT_TRUE is the idiomatic gtest form for boolean conditions and gives
  // a clearer failure message than EXPECT_EQ(cond, true).
  EXPECT_TRUE(select_rows_dtype.dtype().isa<ir::Float32Type>());
  EXPECT_EQ(select_rows_dtype.dims(), dims);
  EXPECT_EQ(select_rows_dtype.data_layout(), data_layout);
  EXPECT_EQ(select_rows_dtype.lod(), lod);
  EXPECT_EQ(select_rows_dtype.offset(), offset);
}
namespace TestNamespace {
// Empty marker class — presumably used by later tests (outside this fragment)
// to exercise the type system with a user-defined C++ type; confirm usage.
class TestClass {};
} // namespace TestNamespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册